diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000000..fd577eb919 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,33 @@ +# -- repository yaml -- + +# Explicitly wait for all jobs to finish, as wait_for_ci prematurely triggers. +# See https://github.com/python-trio/trio/issues/2689 +codecov: + notify: + # This number needs to be changed whenever the number of runs in CI is changed. + # Another option is codecov-cli: https://github.com/codecov/codecov-cli#send-notifications + after_n_builds: 31 + wait_for_ci: false + notify_error: true # if uploads fail, replace cov comment with a comment with errors. + require_ci_to_pass: false + + # Publicly exposing the token has some small risks from mistakes or malicious actors. + # See https://docs.codecov.com/docs/codecov-tokens for correctly configuring it. + token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + +# only post PR comment if coverage changes +comment: + require_changes: true + +coverage: + # required range + precision: 5 + round: down + range: 100..100 + status: + project: + default: + target: 100% + patch: + default: + target: 100% # require patches to be 100% diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 1d3079ad5a..001d83c4a1 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,2 +1,4 @@ # sorting all imports with isort 933f77b96f0092e1baab4474a9208fc2e379aa32 +# enabling ruff's flake8-commas rule +b25c02a94e2defcb0fad32976b02218be1133bdf diff --git a/.github/workflows/autodeps.yml b/.github/workflows/autodeps.yml index 0e0655c5aa..1182d38782 100644 --- a/.github/workflows/autodeps.yml +++ b/.github/workflows/autodeps.yml @@ -17,26 +17,35 @@ jobs: issues: write repository-projects: write contents: write + steps: - name: Checkout uses: actions/checkout@v4 - name: Setup python uses: actions/setup-python@v5 with: - python-version: "3.8" + python-version: "3.9" + - name: Bump dependencies run: | python -m pip install -U pip pre-commit python -m pip install -r test-requirements.txt - uv pip compile --universal --python-version=3.8 --upgrade test-requirements.in -o test-requirements.txt - uv pip compile --universal --python-version=3.8 --upgrade docs-requirements.in -o docs-requirements.txt + uv pip compile --universal --python-version=3.9 --upgrade test-requirements.in -o test-requirements.txt + uv pip compile --universal --python-version=3.11 --upgrade docs-requirements.in -o docs-requirements.txt pre-commit autoupdate --jobs 0 + + - name: Install new requirements + run: python -m pip install -r test-requirements.txt + + # apply newer versions' formatting - name: Black + run: black src/trio + + - name: uv run: | - # The new dependencies may contain a new black version. - # Commit any changes immediately. - python -m pip install -r test-requirements.txt - black src/trio + uv pip compile --universal --python-version=3.9 test-requirements.in -o test-requirements.txt + uv pip compile --universal --python-version=3.11 docs-requirements.in -o docs-requirements.txt + - name: Commit changes and create automerge PR env: GH_TOKEN: ${{ github.token }} diff --git a/.github/workflows/check-newsfragment.yml b/.github/workflows/check-newsfragment.yml new file mode 100644 index 0000000000..0aa78fcd3c --- /dev/null +++ b/.github/workflows/check-newsfragment.yml @@ -0,0 +1,23 @@ +name: Check newsfragment + +on: + pull_request: + types: [labeled, unlabeled, opened, synchronize] + branches: + - main + +jobs: + check-newsfragment: + if: ${{ !contains(github.event.pull_request.labels.*.name, 'skip newsfragment') }} + runs-on: 'ubuntu-latest' + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Check newsfragments + run: | + if git diff --name-only origin/main | grep -v '/_tests/' | grep 'src/trio/'; then + git diff --name-only origin/main | grep 'newsfragments/' || exit 1 + fi diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 60f8b79c03..7a30132613 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,38 +3,171 @@ name: CI on: push: branches-ignore: - - "dependabot/**" + # these branches always have another event associated + - gh-readonly-queue/** # GitHub's merge queue uses `merge_group` + - autodeps/** # autodeps always makes a PR + - pre-commit-ci-update-config # pre-commit.ci's updates always have a PR pull_request: + merge_group: concurrency: group: ${{ github.ref }}-${{ github.workflow }}-${{ github.event_name }}${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) && format('-{0}', github.sha) || '' }} cancel-in-progress: true +env: + dists-artifact-name: python-package-distributions + dist-name: trio + jobs: + build: + name: ๐Ÿ‘ท dists + + runs-on: ubuntu-latest + + outputs: + dist-version: ${{ steps.dist-version.outputs.version }} + sdist-artifact-name: ${{ steps.artifact-name.outputs.sdist }} + wheel-artifact-name: ${{ steps.artifact-name.outputs.wheel }} + + steps: + - name: Switch to using Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: 3.11 + + - name: Grab the source from Git + uses: actions/checkout@v4 + + - name: Get the dist version + id: dist-version + run: >- + echo "version=$( + grep ^__version__ src/trio/_version.py + | sed 's#__version__ = "\([^"]\+\)"#\1#' + )" + >> "${GITHUB_OUTPUT}" + + - name: Set the expected dist artifact names + id: artifact-name + run: | + echo 'sdist=${{ env.dist-name }}-*.tar.gz' >> "${GITHUB_OUTPUT}" + echo 'wheel=${{ + env.dist-name + }}-*-py3-none-any.whl' >> "${GITHUB_OUTPUT}" + + - name: Install build + run: python -Im pip install build + + - name: Build dists + run: python -Im build + - name: Verify that the artifacts with expected names got created + run: >- + ls -1 + dist/${{ steps.artifact-name.outputs.sdist }} + dist/${{ steps.artifact-name.outputs.wheel }} + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: ${{ env.dists-artifact-name }} + # NOTE: Exact expected file names are specified here + # NOTE: as a safety measure โ€” if anything weird ends + # NOTE: up being in this dir or not all dists will be + # NOTE: produced, this will fail the workflow. + path: | + dist/${{ steps.artifact-name.outputs.sdist }} + dist/${{ steps.artifact-name.outputs.wheel }} + retention-days: 5 + + - name: >- + Smoke-test: + retrieve the project source from an sdist inside the GHA artifact + uses: re-actors/checkout-python-sdist@release/v2 + with: + source-tarball-name: ${{ steps.artifact-name.outputs.sdist }} + workflow-artifact-name: ${{ env.dists-artifact-name }} + + - name: >- + Smoke-test: move the sdist-retrieved dir into sdist-src + run: | + mv -v '${{ github.workspace }}' '${{ runner.temp }}/sdist-src' + mkdir -pv '${{ github.workspace }}' + mv -v '${{ runner.temp }}/sdist-src' '${{ github.workspace }}/sdist-src' + shell: bash -eEuo pipefail {0} + + - name: >- + Smoke-test: grab the source from Git into git-src + uses: actions/checkout@v4 + with: + path: git-src + + - name: >- + Smoke-test: install test requirements from the Git repo + run: >- + python -Im + pip install -c test-requirements.txt -r test-requirements.txt + shell: bash -eEuo pipefail {0} + working-directory: git-src + + - name: >- + Smoke-test: collect tests from the Git repo + env: + PYTHONPATH: src/ + run: >- + pytest --collect-only -qq . + | sort + | tee collected-tests + shell: bash -eEuo pipefail {0} + working-directory: git-src + + - name: >- + Smoke-test: collect tests from the sdist tarball + env: + PYTHONPATH: src/ + run: >- + pytest --collect-only -qq . + | sort + | tee collected-tests + shell: bash -eEuo pipefail {0} + working-directory: sdist-src + + - name: >- + Smoke-test: + verify that all the tests from Git are included in the sdist + run: diff --unified sdist-src/collected-tests git-src/collected-tests + shell: bash -eEuo pipefail {0} + Windows: name: 'Windows (${{ matrix.python }}, ${{ matrix.arch }}${{ matrix.extra_name }})' + needs: + - build + timeout-minutes: 20 runs-on: 'windows-latest' strategy: fail-fast: false matrix: - python: ['3.8', '3.9', '3.10', '3.11', '3.12'] + python: ['3.9', '3.10', '3.11', '3.12', '3.13'] arch: ['x86', 'x64'] lsp: [''] lsp_extract_file: [''] extra_name: [''] include: - - python: '3.8' + - python: '3.9' arch: 'x64' lsp: 'https://raw.githubusercontent.com/python-trio/trio-ci-assets/master/komodia-based-vpn-setup.zip' lsp_extract_file: 'komodia-based-vpn-setup.exe' extra_name: ', with Komodia LSP' - - python: '3.8' + - python: '3.9' arch: 'x64' lsp: 'https://www.proxifier.com/download/legacy/ProxifierSetup342.exe' lsp_extract_file: '' extra_name: ', with IFS LSP' - #- python: '3.8' + - python: 'pypy-3.10' + arch: 'x64' + lsp: '' + lsp_extract_file: '' + extra_name: '' + #- python: '3.9' # arch: 'x64' # lsp: 'http://download.pctools.com/mirror/updates/9.0.0.2308-SDavfree-lite_en.exe' # lsp_extract_file: '' @@ -49,8 +182,11 @@ jobs: || false }} steps: - - name: Checkout - uses: actions/checkout@v4 + - name: Retrieve the project source from an sdist inside the GHA artifact + uses: re-actors/checkout-python-sdist@release/v2 + with: + source-tarball-name: ${{ needs.build.outputs.sdist-artifact-name }} + workflow-artifact-name: ${{ env.dists-artifact-name }} - name: Setup python uses: actions/setup-python@v5 with: @@ -76,27 +212,33 @@ jobs: uses: codecov/codecov-action@v3 with: directory: empty - token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 name: Windows (${{ matrix.python }}, ${{ matrix.arch }}${{ matrix.extra_name }}) + # multiple flags is marked as an error in codecov UI, but is actually fine + # https://github.com/codecov/feedback/issues/567 flags: Windows,${{ matrix.python }} + # this option cannot be set in .codecov.yml + fail_ci_if_error: true Ubuntu: name: 'Ubuntu (${{ matrix.python }}${{ matrix.extra_name }})' + needs: + - build + timeout-minutes: 10 runs-on: 'ubuntu-latest' strategy: fail-fast: false matrix: - python: ['pypy-3.9', 'pypy-3.10', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] + python: ['pypy-3.10', '3.9', '3.10', '3.11', '3.12', '3.13'] check_formatting: ['0'] no_test_requirements: ['0'] extra_name: [''] include: - - python: '3.8' + - python: '3.13' check_formatting: '1' extra_name: ', check formatting' # separate test run that doesn't install test-requirements.txt - - python: '3.8' + - python: '3.9' no_test_requirements: '1' extra_name: ', no test-requirements' continue-on-error: >- @@ -109,41 +251,47 @@ jobs: || false }} steps: - - name: Checkout + - name: Retrieve the project source from an sdist inside the GHA artifact + if: matrix.check_formatting != '1' + uses: re-actors/checkout-python-sdist@release/v2 + with: + source-tarball-name: ${{ needs.build.outputs.sdist-artifact-name }} + workflow-artifact-name: ${{ env.dists-artifact-name }} + - name: Grab the source from Git + if: matrix.check_formatting == '1' uses: actions/checkout@v4 - name: Setup python uses: actions/setup-python@v5 - if: "!endsWith(matrix.python, '-dev')" with: python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} cache: pip cache-dependency-path: test-requirements.txt - - name: Setup python (dev) - uses: deadsnakes/action@v2.0.2 - if: endsWith(matrix.python, '-dev') - with: - python-version: '${{ matrix.python }}' - name: Run tests run: ./ci.sh env: CHECK_FORMATTING: '${{ matrix.check_formatting }}' NO_TEST_REQUIREMENTS: '${{ matrix.no_test_requirements }}' - - if: always() + - if: >- + always() + && matrix.check_formatting != '1' uses: codecov/codecov-action@v3 with: directory: empty - token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 name: Ubuntu (${{ matrix.python }}${{ matrix.extra_name }}) flags: Ubuntu,${{ matrix.python }} + fail_ci_if_error: true macOS: name: 'macOS (${{ matrix.python }})' + needs: + - build + timeout-minutes: 15 runs-on: 'macos-latest' strategy: fail-fast: false matrix: - python: ['3.8', '3.9', '3.10', '3.11', '3.12'] + python: ['pypy-3.10', '3.9', '3.10', '3.11', '3.12', '3.13'] continue-on-error: >- ${{ ( @@ -154,8 +302,11 @@ jobs: || false }} steps: - - name: Checkout - uses: actions/checkout@v4 + - name: Retrieve the project source from an sdist inside the GHA artifact + uses: re-actors/checkout-python-sdist@release/v2 + with: + source-tarball-name: ${{ needs.build.outputs.sdist-artifact-name }} + workflow-artifact-name: ${{ env.dists-artifact-name }} - name: Setup python uses: actions/setup-python@v5 with: @@ -168,45 +319,70 @@ jobs: uses: codecov/codecov-action@v3 with: directory: empty - token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 name: macOS (${{ matrix.python }}) flags: macOS,${{ matrix.python }} + fail_ci_if_error: true # run CI on a musl linux Alpine: name: "Alpine" + needs: + - build + runs-on: ubuntu-latest container: alpine steps: - - name: Checkout - uses: actions/checkout@v4 - name: Install necessary packages # can't use setup-python because that python doesn't seem to work; # `python3-dev` (rather than `python:alpine`) for some ctypes reason, # `nodejs` for pyright (`node-env` pulls in nodejs but that takes a while and can time out the test). - run: apk update && apk add python3-dev bash nodejs + # `perl` for a platform independent `sed -i` alternative + run: apk update && apk add python3-dev bash nodejs perl + - name: Retrieve the project source from an sdist inside the GHA artifact + # must be after `apk add` because it relies on `bash` existing + uses: re-actors/checkout-python-sdist@release/v2 + with: + source-tarball-name: ${{ needs.build.outputs.sdist-artifact-name }} + workflow-artifact-name: ${{ env.dists-artifact-name }} - name: Enter virtual environment run: python -m venv .venv - name: Run tests run: source .venv/bin/activate && ./ci.sh + - name: Get Python version for codecov flag + id: get-version + run: echo "version=$(python -V | cut -d' ' -f2 | cut -d'.' -f1,2)" >> "${GITHUB_OUTPUT}" - if: always() uses: codecov/codecov-action@v3 with: directory: empty - token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 name: Alpine - flags: Alpine,3.12 + flags: Alpine,${{ steps.get-version.outputs.version }} + fail_ci_if_error: true Cython: name: "Cython" + needs: + - build + runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python: ['3.8', '3.12'] + include: + - python: '3.9' # We support running on cython 2 and 3 for 3.9 + cython: '<3' # cython 2 + - python: '3.9' + cython: '>=3' # cython 3 (or greater) + - python: '3.11' # 3.11 is the last version Cy2 supports + cython: '<3' # cython 2 + - python: '3.13' # We support running cython3 on 3.13 + cython: '>=3' # cython 3 (or greater) steps: - - name: Checkout - uses: actions/checkout@v4 + - name: Retrieve the project source from an sdist inside the GHA artifact + uses: re-actors/checkout-python-sdist@release/v2 + with: + source-tarball-name: ${{ needs.build.outputs.sdist-artifact-name }} + workflow-artifact-name: ${{ env.dists-artifact-name }} - name: Setup python uses: actions/setup-python@v5 with: @@ -214,22 +390,34 @@ jobs: cache: pip # setuptools is needed to get distutils on 3.12, which cythonize requires - name: install trio and setuptools - run: python -m pip install --upgrade pip . setuptools + run: python -m pip install --upgrade pip . setuptools 'coverage[toml]' - - name: install cython<3 - run: python -m pip install "cython<3" - - name: compile pyx file - run: cythonize -i tests/cython/test_cython.pyx - - name: import & run module - run: python -c 'import tests.cython.test_cython' + - name: add cython plugin to the coveragepy config + run: >- + sed -i 's#plugins\s=\s\[\]#plugins = ["Cython.Coverage"]#' + pyproject.toml + + - name: install cython & compile pyx file + env: + CFLAGS: ${{ env.CFLAGS }} -DCYTHON_TRACE_NOGIL=1 + run: | + python -m pip install "cython${{ matrix.cython }}" + cythonize --inplace -X linetrace=True tests/cython/test_cython.pyx - - name: install cython>=3 - run: python -m pip install "cython>=3" - - name: compile pyx file - # different cython version should trigger a re-compile, but --force just in case - run: cythonize --inplace --force tests/cython/test_cython.pyx - name: import & run module - run: python -c 'import tests.cython.test_cython' + run: coverage run -m tests.cython.run_test_cython + + - name: get Python version for codecov flag + id: get-version + run: >- + echo "version=$(python -V | cut -d' ' -f2 | cut -d'.' -f1,2)" + >> "${GITHUB_OUTPUT}" + - if: always() + uses: codecov/codecov-action@v5 + with: + name: Cython + flags: Cython,${{ steps.get-version.outputs.version }} + fail_ci_if_error: true # https://github.com/marketplace/actions/alls-green#why check: # This job does nothing and is only used for the branch protection diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000000..a40da4c3af --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,44 @@ +on: + push: + tags: + - v* + +# a lot of code taken from https://github.com/pypa/cibuildwheel/blob/main/examples/github-deploy.yml +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.9" + - run: python -m pip install build + - run: python -m build + + - uses: actions/upload-artifact@v4 + with: + name: trio-dist + path: | + dist/*.tar.gz + dist/*.whl + + pypi-publish: + needs: [build] + name: upload release to PyPI + runs-on: ubuntu-latest + environment: + name: release + url: https://pypi.org/project/trio + permissions: + id-token: write + + steps: + - uses: actions/download-artifact@v4 + with: + pattern: trio-* + path: dist + merge-multiple: true + + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 388f2dfbda..cad76cb460 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# In case somebody wants to restore the directory for local testing +notes-to-self/ + # Project-specific generated files docs/build/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 810c28e4a2..415a6d2364 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,13 +1,12 @@ ci: - autofix_commit_msg: "[pre-commit.ci] auto fixes from pre-commit.com hooks" - autofix_prs: false - autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" + autofix_prs: true autoupdate_schedule: weekly submodules: false + skip: [regenerate-files] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -19,11 +18,11 @@ repos: - id: sort-simple-yaml files: .pre-commit-config.yaml - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.4.2 + rev: 24.10.0 hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.5 + rev: v0.8.3 hooks: - id: ruff types: [file] @@ -33,3 +32,15 @@ repos: rev: v2.3.0 hooks: - id: codespell + - repo: https://github.com/sphinx-contrib/sphinx-lint + rev: v1.0.0 + hooks: + - id: sphinx-lint + - repo: local + hooks: + - id: regenerate-files + name: regenerate generated files + language: system + entry: python src/trio/_tools/gen_exports.py + pass_filenames: false + files: ^src\/trio\/_core\/(_run|(_i(o_(common|epoll|kqueue|windows)|nstrumentation)))\.py$ diff --git a/MANIFEST.in b/MANIFEST.in index 440994e43a..5ab28eabbd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,14 @@ +include .codecov.yml +include check.sh +include ci.sh include LICENSE LICENSE.MIT LICENSE.APACHE2 include README.rst include CODE_OF_CONDUCT.md CONTRIBUTING.md -include test-requirements.txt +include *-requirements.in +include *-requirements.txt include src/trio/py.typed +include src/trio/_tests/astrill-codesigning-cert.cer recursive-include src/trio/_tests/test_ssl_certs *.pem recursive-include docs * +recursive-include tests * prune docs/build diff --git a/README.rst b/README.rst index 65f6df8946..e3620546a0 100644 --- a/README.rst +++ b/README.rst @@ -92,7 +92,7 @@ demonstration of implementing the "Happy Eyeballs" algorithm in an older library versus Trio. **Cool, but will it work on my system?** Probably! As long as you have -some kind of Python 3.8-or-better (CPython or `currently maintained versions of +some kind of Python 3.9-or-better (CPython or `currently maintained versions of PyPy3 `__ are both fine), and are using Linux, macOS, Windows, or FreeBSD, then Trio will work. Other environments might work too, but those diff --git a/check.sh b/check.sh index 85e1227b16..a1efa66a20 100755 --- a/check.sh +++ b/check.sh @@ -78,10 +78,10 @@ fi # Check pip compile is consistent echo "::group::Pip Compile - Tests" -uv pip compile --universal --python-version=3.8 test-requirements.in -o test-requirements.txt +uv pip compile --universal --python-version=3.9 test-requirements.in -o test-requirements.txt echo "::endgroup::" echo "::group::Pip Compile - Docs" -uv pip compile --universal --python-version=3.8 docs-requirements.in -o docs-requirements.txt +uv pip compile --universal --python-version=3.11 docs-requirements.in -o docs-requirements.txt echo "::endgroup::" if git status --porcelain | grep -q "requirements.txt"; then @@ -112,7 +112,7 @@ if [ $EXIT_STATUS -ne 0 ]; then Problems were found by static analysis (listed above). To fix formatting and see remaining errors, run - pip install -r test-requirements.txt + uv pip install -r test-requirements.txt black src/trio ruff check src/trio ./check.sh diff --git a/ci.sh b/ci.sh index 112ed04d7a..83ec65748b 100755 --- a/ci.sh +++ b/ci.sh @@ -37,25 +37,29 @@ python -c "import sys, struct, ssl; print('python:', sys.version); print('versio echo "::endgroup::" echo "::group::Install dependencies" -python -m pip install -U pip build +python -m pip install -U pip uv -c test-requirements.txt python -m pip --version +python -m uv --version + +python -m uv pip install build python -m build -python -m pip install dist/*.whl +wheel_package=$(ls dist/*.whl) +python -m uv pip install "trio @ $wheel_package" -c test-requirements.txt if [ "$CHECK_FORMATTING" = "1" ]; then - python -m pip install -r test-requirements.txt + python -m uv pip install -r test-requirements.txt exceptiongroup echo "::endgroup::" source check.sh else # Actual tests # expands to 0 != 1 if NO_TEST_REQUIREMENTS is not set, if set the `-0` has no effect # https://pubs.opengroup.org/onlinepubs/9699919799/utilities/V3_chap02.html#tag_18_06_02 - if [ ${NO_TEST_REQUIREMENTS-0} == 1 ]; then - python -m pip install pytest coverage + if [ "${NO_TEST_REQUIREMENTS-0}" == 1 ]; then + python -m uv pip install pytest coverage -c test-requirements.txt flags="--skip-optional-imports" else - python -m pip install -r test-requirements.txt + python -m uv pip install -r test-requirements.txt flags="" fi @@ -112,13 +116,13 @@ else echo "::group::Setup for tests" # We run the tests from inside an empty directory, to make sure Python - # doesn't pick up any .py files from our working dir. Might have been - # pre-created by some of the code above. + # doesn't pick up any .py files from our working dir. Might have already + # been created by a previous run. mkdir empty || true cd empty INSTALLDIR=$(python -c "import os, trio; print(os.path.dirname(trio.__file__))") - cp ../pyproject.toml $INSTALLDIR + cp ../pyproject.toml "$INSTALLDIR" # TODO: remove this # get mypy tests a nice cache MYPYPATH=".." mypy --config-file= --cache-dir=./.mypy_cache -c "import trio" >/dev/null 2>/dev/null || true @@ -126,9 +130,15 @@ else # support subprocess spawning with coverage.py echo "import coverage; coverage.process_startup()" | tee -a "$INSTALLDIR/../sitecustomize.py" + perl -i -pe 's/-p trio\._tests\.pytest_plugin//' "$INSTALLDIR/pyproject.toml" + echo "::endgroup::" echo "::group:: Run Tests" - if COVERAGE_PROCESS_START=$(pwd)/../pyproject.toml coverage run --rcfile=../pyproject.toml -m pytest -ra --junitxml=../test-results.xml --run-slow ${INSTALLDIR} --verbose --durations=10 $flags; then + if PYTHONPATH=../tests COVERAGE_PROCESS_START=$(pwd)/../pyproject.toml \ + coverage run --rcfile=../pyproject.toml -m \ + pytest -ra --junitxml=../test-results.xml \ + -p _trio_check_attrs_aliases --verbose --durations=10 \ + -p trio._tests.pytest_plugin --run-slow $flags "${INSTALLDIR}"; then PASSED=true else PASSED=false diff --git a/docs-requirements.in b/docs-requirements.in index c4695fc688..1a571d9832 100644 --- a/docs-requirements.in +++ b/docs-requirements.in @@ -2,7 +2,8 @@ # sphinx 5.3 doesn't work with our _NoValue workaround sphinx >= 6.0 jinja2 -sphinx_rtd_theme +# >= is necessary to prevent `uv` from selecting a `Sphinx` version this does not support +sphinx_rtd_theme >= 3 sphinxcontrib-jquery sphinxcontrib-trio towncrier diff --git a/docs-requirements.txt b/docs-requirements.txt index 461d6e3d93..03cefbc9a4 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -1,70 +1,62 @@ # This file was autogenerated by uv via the following command: -# uv pip compile --universal --python-version=3.8 docs-requirements.in -o docs-requirements.txt -alabaster==0.7.13 +# uv pip compile --universal --python-version=3.11 docs-requirements.in -o docs-requirements.txt +alabaster==1.0.0 # via sphinx -attrs==23.2.0 +attrs==24.2.0 # via # -r docs-requirements.in # outcome -babel==2.15.0 +babel==2.16.0 # via sphinx beautifulsoup4==4.12.3 # via sphinx-codeautolink -certifi==2024.7.4 +certifi==2024.8.30 # via requests -cffi==1.16.0 ; os_name == 'nt' or platform_python_implementation != 'PyPy' +cffi==1.17.1 ; platform_python_implementation != 'PyPy' or os_name == 'nt' # via # -r docs-requirements.in # cryptography -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 # via requests click==8.1.7 # via towncrier -colorama==0.4.6 ; platform_system == 'Windows' or sys_platform == 'win32' +colorama==0.4.6 ; sys_platform == 'win32' or platform_system == 'Windows' # via # click # sphinx -cryptography==42.0.8 +cryptography==44.0.0 # via pyopenssl -docutils==0.20.1 +docutils==0.21.2 # via # sphinx # sphinx-rtd-theme -exceptiongroup==1.2.1 +exceptiongroup==1.2.2 # via -r docs-requirements.in -idna==3.7 +idna==3.10 # via # -r docs-requirements.in # requests imagesize==1.4.1 # via sphinx -immutables==0.20 +immutables==0.21 # via -r docs-requirements.in -importlib-metadata==8.0.0 ; python_version < '3.10' - # via sphinx -importlib-resources==6.4.0 ; python_version < '3.10' - # via towncrier -incremental==22.10.0 - # via towncrier jinja2==3.1.4 # via # -r docs-requirements.in # sphinx # towncrier -markupsafe==2.1.5 +markupsafe==3.0.2 # via jinja2 outcome==1.3.0.post0 # via -r docs-requirements.in -packaging==24.1 +packaging==24.2 # via sphinx -pycparser==2.22 ; os_name == 'nt' or platform_python_implementation != 'PyPy' +pycparser==2.22 ; platform_python_implementation != 'PyPy' or os_name == 'nt' # via cffi pygments==2.18.0 # via sphinx -pyopenssl==24.1.0 +pyopenssl==24.3.0 # via -r docs-requirements.in -pytz==2024.1 ; python_version < '3.9' - # via babel requests==2.32.3 # via sphinx sniffio==1.3.1 @@ -73,9 +65,9 @@ snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via -r docs-requirements.in -soupsieve==2.5 +soupsieve==2.6 # via beautifulsoup4 -sphinx==7.1.2 +sphinx==8.1.3 # via # -r docs-requirements.in # sphinx-codeautolink @@ -85,15 +77,15 @@ sphinx==7.1.2 # sphinxcontrib-trio sphinx-codeautolink==0.15.2 # via -r docs-requirements.in -sphinx-hoverxref==1.4.0 +sphinx-hoverxref==1.4.2 # via -r docs-requirements.in -sphinx-rtd-theme==2.0.0 +sphinx-rtd-theme==3.0.2 # via -r docs-requirements.in -sphinxcontrib-applehelp==1.0.4 +sphinxcontrib-applehelp==2.0.0 # via sphinx -sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-devhelp==2.0.0 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 +sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jquery==4.1 # via @@ -102,19 +94,13 @@ sphinxcontrib-jquery==4.1 # sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-qthelp==2.0.0 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib-serializinghtml==2.0.0 # via sphinx sphinxcontrib-trio==1.1.2 # via -r docs-requirements.in -tomli==2.0.1 ; python_version < '3.11' - # via towncrier -towncrier==23.11.0 +towncrier==24.8.0 # via -r docs-requirements.in -urllib3==2.2.2 +urllib3==2.2.3 # via requests -zipp==3.19.2 ; python_version < '3.10' - # via - # importlib-metadata - # importlib-resources diff --git a/docs/source/awesome-trio-libraries.rst b/docs/source/awesome-trio-libraries.rst index 823bf0779a..875471a21f 100644 --- a/docs/source/awesome-trio-libraries.rst +++ b/docs/source/awesome-trio-libraries.rst @@ -29,8 +29,8 @@ Web and HTML ------------ * `httpx `__ - HTTPX is a fully featured HTTP client for Python 3, which provides sync and async APIs, and support for both HTTP/1.1 and HTTP/2. * `trio-websocket `__ - A WebSocket client and server implementation striving for safety, correctness, and ergonomics. -* `quart-trio `__ - Like Flask, but for Trio. A simple and powerful framework for building async web applications and REST APIs. Tip: this is an ASGI-based framework, so you'll also need an HTTP server with ASGI support. -* `hypercorn `__ - An HTTP server for hosting your ASGI apps. Supports HTTP/1.1, HTTP/2, HTTP/3, and Websockets. Can be run as a standalone server, or embedded in a larger Trio app. Use it with ``quart-trio``, or any other Trio-compatible ASGI framework. +* `quart-trio `__ - Like Flask, but for Trio. A simple and powerful framework for building async web applications and REST APIs. Tip: this is an ASGI-based framework, so you'll also need an HTTP server with ASGI support. +* `hypercorn `__ - An HTTP server for hosting your ASGI apps. Supports HTTP/1.1, HTTP/2, HTTP/3, and Websockets. Can be run as a standalone server, or embedded in a larger Trio app. Use it with ``quart-trio``, or any other Trio-compatible ASGI framework. * `DeFramed `__ - DeFramed is a Web non-framework that supports a 99%-server-centric approach to Web coding, including support for the `Remi `__ GUI library. * `pura `__ - A simple web framework for embedding realtime graphical visualization into Trio apps, enabling inspection and manipulation of program state during development. * `pyscalpel `__ - A fast and powerful webscraping library. @@ -108,6 +108,8 @@ Tools and Utilities * `aiometer `__ - Execute lots of tasks concurrently while controlling concurrency limits * `triotp `__ - OTP framework for Python Trio * `aioresult `__ - Get the return value of a background async function in Trio or anyio, along with a simple Future class and wait utilities +* `aiologic `__ - Thread-safe synchronization and communication primitives: locks, capacity limiters, queues, etc. +* `culsans `__ - Janus-like sync-async queue with Trio support. Unlike aiologic queues, provides API compatible interfaces. Trio/Asyncio Interoperability diff --git a/docs/source/conf.py b/docs/source/conf.py index 7ea27de24b..fb8e60cdc5 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,9 +19,10 @@ from __future__ import annotations import collections.abc +import glob import os import sys -import types +from pathlib import Path from typing import TYPE_CHECKING, cast if TYPE_CHECKING: @@ -36,19 +37,62 @@ # Enable reloading with `typing.TYPE_CHECKING` being True os.environ["SPHINX_AUTODOC_RELOAD_MODULES"] = "1" -# https://docs.readthedocs.io/en/stable/builds.html#build-environment -if "READTHEDOCS" in os.environ: - import glob - - if glob.glob("../../newsfragments/*.*.rst"): - print("-- Found newsfragments; running towncrier --", flush=True) - import subprocess - +# Handle writing newsfragments into the history file. +# We want to keep files unchanged when testing locally. +# So immediately revert the contents after running towncrier, +# then substitute when Sphinx wants to read it in. +history_file = Path("history.rst") + +history_new: str | None +if glob.glob("../../newsfragments/*.*.rst"): + print("-- Found newsfragments; running towncrier --", flush=True) + history_orig = history_file.read_bytes() + import subprocess + + # In case changes were staged, preserve indexed version. + # This grabs the hash of the current staged version. + history_staged = subprocess.run( + ["git", "rev-parse", "--verify", ":docs/source/history.rst"], + check=True, + cwd="../..", + stdout=subprocess.PIPE, + encoding="ascii", + ).stdout.strip() + try: subprocess.run( - ["towncrier", "--yes", "--date", "not released yet"], + ["towncrier", "--keep", "--date", "not released yet"], cwd="../..", check=True, ) + history_new = history_file.read_text("utf8") + finally: + # Make sure this reverts even if a failure occurred. + # Restore whatever was staged. + print(f"Restoring history.rst = {history_staged}") + subprocess.run( + [ + "git", + "update-index", + "--cacheinfo", + f"100644,{history_staged},docs/source/history.rst", + ], + cwd="../..", + check=False, + ) + # And restore the working copy. + history_file.write_bytes(history_orig) + del history_orig # We don't need this any more. +else: + # Leave it as is. + history_new = None + + +def on_read_source(app: Sphinx, docname: str, content: list[str]) -> None: + """Substitute the modified history file.""" + if docname == "history" and history_new is not None: + # This is a 1-item list with the file contents. + content[0] = history_new + # Sphinx is very finicky, and somewhat buggy, so we have several different # methods to help it resolve links. @@ -107,16 +151,6 @@ def autodoc_process_signature( return_annotation: str, ) -> tuple[str, str]: """Modify found signatures to fix various issues.""" - if name == "trio.testing._raises_group._ExceptionInfo.type": - # This has the type "type[E]", which gets resolved into the property itself. - # That means Sphinx can't resolve it. Fix the issue by overwriting with a fully-qualified - # name. - assert isinstance(obj, property), obj - assert isinstance(obj.fget, types.FunctionType), obj.fget - assert ( - obj.fget.__annotations__["return"] == "type[MatchE]" - ), obj.fget.__annotations__ - obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.MatchE]" if signature is not None: signature = signature.replace("~_contextvars.Context", "~contextvars.Context") if name == "trio.lowlevel.RunVar": # Typevar is not useful here. @@ -125,15 +159,6 @@ def autodoc_process_signature( # Strip the type from the union, make it look like = ... signature = signature.replace(" | type[trio._core._local._NoValue]", "") signature = signature.replace("", "...") - if name in ("trio.testing.RaisesGroup", "trio.testing.Matcher") and ( - "+E" in signature or "+MatchE" in signature - ): - # This typevar being covariant isn't handled correctly in some cases, strip the + - # and insert the fully-qualified name. - signature = signature.replace("+E", "~trio.testing._raises_group.E") - signature = signature.replace( - "+MatchE", "~trio.testing._raises_group.MatchE" - ) if "DTLS" in name: signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context") # Don't specify PathLike[str] | PathLike[bytes], this is just for humans. @@ -152,7 +177,15 @@ def setup(app: Sphinx) -> None: app.connect("autodoc-process-signature", autodoc_process_signature) # After Intersphinx runs, add additional mappings. app.connect("builder-inited", add_intersphinx, priority=1000) + app.connect("source-read", on_read_source) + +# Our docs use the READTHEDOCS variable, so copied from: +# https://about.readthedocs.com/blog/2024/07/addons-by-default/ +if os.environ.get("READTHEDOCS", "") == "True": + if "html_context" not in globals(): + html_context = {} + html_context["READTHEDOCS"] = True # -- General configuration ------------------------------------------------ @@ -182,6 +215,7 @@ def setup(app: Sphinx) -> None: "pyopenssl": ("https://www.pyopenssl.org/en/stable/", None), "sniffio": ("https://sniffio.readthedocs.io/en/latest/", None), "trio-util": ("https://trio-util.readthedocs.io/en/latest/", None), + "flake8-async": ("https://flake8-async.readthedocs.io/en/latest/", None), } # See https://sphinx-hoverxref.readthedocs.io/en/latest/configuration.html @@ -243,12 +277,14 @@ def add_mapping( # This has been removed in Py3.12, so add a link to the 3.11 version with deprecation warnings. add_mapping("method", "pathlib", "Path.link_to", "3.11") + # defined in py:data in objects.inv, but sphinx looks for a py:class + # see https://github.com/sphinx-doc/sphinx/issues/10974 + # to dump the objects.inv for the stdlib, you can run + # python -m sphinx.ext.intersphinx http://docs.python.org/3/objects.inv add_mapping("class", "math", "inf") - # `types.FrameType.__module__` is "builtins", so sphinx looks for - # builtins.FrameType. - # See https://github.com/sphinx-doc/sphinx/issues/11802 add_mapping("class", "types", "FrameType") + # new in py3.12, and need target because sphinx is unable to look up # the module of the object if compiling on <3.12 if not hasattr(collections.abc, "Buffer"): @@ -324,10 +360,7 @@ def add_mapping( # We have to set this ourselves, not only because it's useful for local # testing, but also because if we don't then RTD will throw away our # html_theme_options. -import sphinx_rtd_theme - html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index f37f57d5dd..cace2943d2 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -176,6 +176,24 @@ This keeps us closer to the desired state where each open issue reflects some work that still needs to be done. +Environment +~~~~~~~~~~~ +We strongly suggest using a virtual environment for managing dependencies, +for example with `venv `__. So to +set up your environment and install dependencies, you should run something like: + +.. code-block:: shell + + cd path/to/trio/checkout/ + python -m venv .venv # create virtual env in .venv + source .venv/bin/activate # activate it + pip install -e . # install trio, needed for pytest plugin + pip install -r test-requirements.txt # install test requirements + +you rarely need to recreate the virtual environment, but you need to re-activate it +in future terminals. You might also need to re-install from test-requirements.txt if +the versions in it get updated. + .. _pull-request-tests: Tests @@ -186,12 +204,11 @@ locally, you should run: .. code-block:: shell - cd path/to/trio/checkout/ - pip install -r test-requirements.txt # possibly using a virtualenv - pytest trio + source .venv/bin/activate # if not already activated + pytest src This doesn't try to be completely exhaustive โ€“ it only checks that -things work on your machine, and it may skip some slow tests. But it's +things work on your machine, and it will skip some slow tests. But it's a good way to quickly check that things seem to be working, and we'll automatically run the full test suite when your PR is submitted, so you'll have a chance to see and fix any remaining issues then. @@ -211,8 +228,14 @@ it being merely hard to fix). For example: We use Codecov to track coverage, because it makes it easy to combine coverage from running in different configurations. Running coverage locally can be useful -(``pytest --cov=PACKAGENAME --cov-report=html``), but don't be -surprised if you get lower coverage than when looking at Codecov + +.. code-block:: shell + + coverage run -m pytest + coverage combine + coverage report + +but don't be surprised if you get lower coverage than when looking at Codecov reports, because there are some lines that are only executed on Windows, or macOS, or PyPy, or CPython, or... you get the idea. After you create a PR, Codecov will automatically report back with the diff --git a/docs/source/history.rst b/docs/source/history.rst index 9cef5191e5..34e2fe9772 100644 --- a/docs/source/history.rst +++ b/docs/source/history.rst @@ -5,6 +5,63 @@ Release history .. towncrier release notes start +Trio 0.27.0 (2024-10-17) +------------------------ + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :func:`trio.move_on_after` and :func:`trio.fail_after` previously set the deadline relative to initialization time, instead of more intuitively upon entering the context manager. This might change timeouts if a program relied on this behavior. If you want to restore previous behavior you should instead use ``trio.move_on_at(trio.current_time() + ...)``. + flake8-async has a new rule to catch this, in case you're supporting older trio versions. See :ref:`ASYNC122`. (`#2512 `__) + + +Features +~~~~~~~~ + +- :meth:`CancelScope.relative_deadline` and :meth:`CancelScope.is_relative` added, as well as a ``relative_deadline`` parameter to ``__init__``. This allows initializing scopes ahead of time, but where the specified relative deadline doesn't count down until the scope is entered. (`#2512 `__) +- :class:`trio.Lock` and :class:`trio.StrictFIFOLock` will now raise :exc:`trio.BrokenResourceError` when :meth:`trio.Lock.acquire` would previously stall due to the owner of the lock exiting without releasing the lock. (`#3035 `__) +- `trio.move_on_at`, `trio.move_on_after`, `trio.fail_at` and `trio.fail_after` now accept *shield* as a keyword argument. If specified, it provides an initial value for the `~trio.CancelScope.shield` attribute of the `trio.CancelScope` object created by the context manager. (`#3052 `__) +- Added :func:`trio.lowlevel.add_parking_lot_breaker` and :func:`trio.lowlevel.remove_parking_lot_breaker` to allow creating custom lock/semaphore implementations that will break their underlying parking lot if a task exits unexpectedly. :meth:`trio.lowlevel.ParkingLot.break_lot` is also added, to allow breaking a parking lot intentionally. (`#3081 `__) + + +Bugfixes +~~~~~~~~ + +- Allow sockets to bind any ``os.PathLike`` object. (`#3041 `__) +- Update ``trio.lowlevel.open_process``'s documentation to allow bytes. (`#3076 `__) +- Update :func:`trio.sleep_forever` to be `NoReturn`. (`#3095 `__) + + +Improved documentation +~~~~~~~~~~~~~~~~~~~~~~ + +- Add docstrings for memory channels' ``statistics()`` and ``aclose`` methods. (`#3101 `__) + + +Trio 0.26.2 (2024-08-08) +------------------------ + +Bugfixes +~~~~~~~~ + +- Remove remaining ``hash`` usage and fix test configuration issue that prevented it from being caught. (`#3053 `__) + + +Trio 0.26.1 (2024-08-05) +------------------------ + +Bugfixes +~~~~~~~~ + +- Switched ``attrs`` usage off of ``hash``, which is now deprecated. (`#3053 `__) + + +Miscellaneous internal changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Use PyPI's Trusted Publishers to make releases. (`#2980 `__) + + Trio 0.26.0 (2024-07-05) ------------------------ @@ -117,7 +174,7 @@ Trio 0.23.2 (2023-12-14) Features ~~~~~~~~ -- `TypeVarTuple `_ is now used to fully type :meth:`nursery.start_soon() `, :func:`trio.run()`, :func:`trio.to_thread.run_sync()`, and other similar functions accepting ``(func, *args)``. This means type checkers will be able to verify types are used correctly. :meth:`nursery.start() ` is not fully typed yet however. (`#2881 `__) +- `TypeVarTuple `_ is now used to fully type :meth:`nursery.start_soon() `, :func:`trio.run`, :func:`trio.to_thread.run_sync`, and other similar functions accepting ``(func, *args)``. This means type checkers will be able to verify types are used correctly. :meth:`nursery.start() ` is not fully typed yet however. (`#2881 `__) Bugfixes @@ -976,7 +1033,7 @@ Bugfixes - Fixed a race condition on macOS, where Trio's TCP listener would crash if an incoming TCP connection was closed before the listener had a chance to accept it. (`#609 `__) -- :func:`trio.open_tcp_stream()` has been refactored to clean up unsuccessful +- :func:`trio.open_tcp_stream` has been refactored to clean up unsuccessful connection attempts more reliably. (`#809 `__) diff --git a/docs/source/index.rst b/docs/source/index.rst index 1caf5c043b..b89bc4533c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -45,7 +45,7 @@ Vital statistics: * Supported environments: We test on - - Python: 3.8+ (CPython and PyPy) + - Python: 3.9+ (CPython and PyPy) - Windows, macOS, Linux (glibc and musl), FreeBSD Other environments might also work; give it a try and see. diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 6808f930c6..a9bb3909d9 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -449,8 +449,7 @@ attribute to :data:`True`: try: await conn.send_hello_msg() finally: - with trio.move_on_after(CLEANUP_TIMEOUT) as cleanup_scope: - cleanup_scope.shield = True + with trio.move_on_after(CLEANUP_TIMEOUT, shield=True) as cleanup_scope: await conn.send_goodbye_msg() So long as you're inside a scope with ``shield = True`` set, then @@ -528,8 +527,12 @@ objects. .. autoattribute:: deadline + .. autoattribute:: relative_deadline + .. autoattribute:: shield + .. automethod:: is_relative() + .. automethod:: cancel() .. attribute:: cancelled_caught @@ -562,7 +565,8 @@ situation of just wanting to impose a timeout on some code: .. autofunction:: fail_at :with: cancel_scope -Cheat sheet: +Cheat sheet ++++++++++++ * If you want to impose a timeout on a function, but you don't care whether it timed out or not: @@ -598,7 +602,6 @@ which is sometimes useful: .. autofunction:: current_effective_deadline - .. _tasks: Tasks let you do multiple things at once @@ -1238,6 +1241,8 @@ more features beyond the core channel interface: .. autoclass:: MemoryReceiveChannel :members: +.. autoclass:: MemoryChannelStatistics + :members: A simple channel example ++++++++++++++++++++++++ diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index e8a967bf17..665f62dd0b 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -555,7 +555,8 @@ there are 1,000,000 ยตs in a second. Note that all the numbers here are going to be rough orders of magnitude to give you a sense of scale; if you need precise numbers for your environment, measure!) -.. file.read benchmark is notes-to-self/file-read-latency.py +.. file.read benchmark is + https://github.com/python-trio/trio/wiki/notes-to-self#file-read-latencypy .. Numbers for spinning disks and SSDs are from taking a few random recent reviews from http://www.storagereview.com/best_drives and looking at their "4K Write Latency" test results for "Average MS" diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 70133b9839..46c8b4d485 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -377,6 +377,46 @@ These transitions are accomplished using two function decorators: poorly-timed :exc:`KeyboardInterrupt` could leave the lock in an inconsistent state and cause a deadlock. + Since KeyboardInterrupt protection is tracked per code object, any attempt to + conditionally protect the same block of code in different ways is unlikely to behave + how you expect. If you try to conditionally protect a closure, it will be + unconditionally protected instead:: + + def example(protect: bool) -> bool: + def inner() -> bool: + return trio.lowlevel.currently_ki_protected() + if protect: + inner = trio.lowlevel.enable_ki_protection(inner) + return inner() + + async def amain(): + assert example(False) == False + assert example(True) == True # once protected ... + assert example(False) == True # ... always protected + + trio.run(amain) + + If you really need conditional protection, you can achieve it by giving each + KI-protected instance of the closure its own code object:: + + def example(protect: bool) -> bool: + def inner() -> bool: + return trio.lowlevel.currently_ki_protected() + if protect: + inner.__code__ = inner.__code__.replace() + inner = trio.lowlevel.enable_ki_protection(inner) + return inner() + + async def amain(): + assert example(False) == False + assert example(True) == True + assert example(False) == False + + trio.run(amain) + + (This isn't done by default because it carries some memory overhead and reduces + the potential for specializing optimizations in recent versions of CPython.) + .. autofunction:: currently_ki_protected @@ -393,6 +433,10 @@ Wait queue abstraction .. autoclass:: ParkingLotStatistics :members: +.. autofunction:: add_parking_lot_breaker + +.. autofunction:: remove_parking_lot_breaker + Low-level checkpoint functions ------------------------------ diff --git a/docs/source/releasing.rst b/docs/source/releasing.rst index e4cb70685d..c4d66b9c15 100644 --- a/docs/source/releasing.rst +++ b/docs/source/releasing.rst @@ -41,13 +41,7 @@ Things to do for releasing: * tag with vVERSION, push tag on ``python-trio/trio`` (not on your personal repository) -* push to PyPI: - - .. code-block:: - - git clean -xdf # maybe run 'git clean -xdn' first to see what it will delete - python3 -m build - twine upload dist/* +* approve the release workflow's publish job * update version number in the same pull request diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index c7218d873b..3ae9bb4597 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -88,7 +88,7 @@ Okay, ready? Let's get started. Before you begin ---------------- -1. Make sure you're using Python 3.8 or newer. +1. Make sure you're using Python 3.9 or newer. 2. ``python3 -m pip install --upgrade trio`` (or on Windows, maybe ``py -3 -m pip install --upgrade trio`` โ€“ `details diff --git a/newsfragments/2670.bugfix.rst b/newsfragments/2670.bugfix.rst new file mode 100644 index 0000000000..cd5ed3b944 --- /dev/null +++ b/newsfragments/2670.bugfix.rst @@ -0,0 +1,2 @@ +:func:`inspect.iscoroutinefunction` and the like now give correct answers when +called on KI-protected functions. diff --git a/newsfragments/3087.doc.rst b/newsfragments/3087.doc.rst new file mode 100644 index 0000000000..68fa4b05ed --- /dev/null +++ b/newsfragments/3087.doc.rst @@ -0,0 +1 @@ +Improve error message when run after gevent's monkey patching. diff --git a/newsfragments/3094.misc.rst b/newsfragments/3094.misc.rst new file mode 100644 index 0000000000..c35b7802e8 --- /dev/null +++ b/newsfragments/3094.misc.rst @@ -0,0 +1 @@ +Switch to using PEP570 for positional-only arguments for `~trio.socket.SocketType`'s methods. diff --git a/newsfragments/3097.removal.rst b/newsfragments/3097.removal.rst new file mode 100644 index 0000000000..1eca349d44 --- /dev/null +++ b/newsfragments/3097.removal.rst @@ -0,0 +1 @@ +Remove workaround for OpenSSL 1.1.1 DTLS ClientHello bug. diff --git a/newsfragments/3106.removal.rst b/newsfragments/3106.removal.rst new file mode 100644 index 0000000000..ee023242d0 --- /dev/null +++ b/newsfragments/3106.removal.rst @@ -0,0 +1 @@ +Drop support for Python 3.8. (`#3104 `__) diff --git a/newsfragments/3108.bugfix.rst b/newsfragments/3108.bugfix.rst new file mode 100644 index 0000000000..16cf46b960 --- /dev/null +++ b/newsfragments/3108.bugfix.rst @@ -0,0 +1,26 @@ +Rework KeyboardInterrupt protection to track code objects, rather than frames, +as protected or not. The new implementation no longer needs to access +``frame.f_locals`` dictionaries, so it won't artificially extend the lifetime of +local variables. Since KeyboardInterrupt protection is now imposed statically +(when a protected function is defined) rather than each time the function runs, +its previously-noticeable performance overhead should now be near zero. +The lack of a call-time wrapper has some other benefits as well: + +* :func:`inspect.iscoroutinefunction` and the like now give correct answers when + called on KI-protected functions. + +* Calling a synchronous KI-protected function no longer pushes an additional stack + frame, so tracebacks are clearer. + +* A synchronous KI-protected function invoked from C code (such as a weakref + finalizer) is now guaranteed to start executing; previously there would be a brief + window in which KeyboardInterrupt could be raised before the protection was + established. + +One minor drawback of the new approach is that multiple instances of the same +closure share a single KeyboardInterrupt protection state (because they share a +single code object). That means that if you apply +`@enable_ki_protection ` to some of them +and not others, you won't get the protection semantics you asked for. See the +documentation of `@enable_ki_protection ` +for more details and a workaround. diff --git a/newsfragments/3112.bugfix.rst b/newsfragments/3112.bugfix.rst new file mode 100644 index 0000000000..c34d035520 --- /dev/null +++ b/newsfragments/3112.bugfix.rst @@ -0,0 +1,5 @@ +Rework foreign async generator finalization to track async generator +ids rather than mutating ``ag_frame.f_locals``. This fixes an issue +with the previous implementation: locals' lifetimes will no longer be +extended by materialization in the ``ag_frame.f_locals`` dictionary that +the previous finalization dispatcher logic needed to access to do its work. diff --git a/newsfragments/3113.doc.rst b/newsfragments/3113.doc.rst new file mode 100644 index 0000000000..7ce407a29e --- /dev/null +++ b/newsfragments/3113.doc.rst @@ -0,0 +1 @@ +Document that :func:`trio.sleep_forever` is guaranteed to raise an exception now. diff --git a/newsfragments/3114.bugfix.rst b/newsfragments/3114.bugfix.rst new file mode 100644 index 0000000000..2f07712199 --- /dev/null +++ b/newsfragments/3114.bugfix.rst @@ -0,0 +1 @@ +Ensure that Pyright recognizes our underscore prefixed attributes for attrs classes. diff --git a/newsfragments/3121.misc.rst b/newsfragments/3121.misc.rst new file mode 100644 index 0000000000..731232877b --- /dev/null +++ b/newsfragments/3121.misc.rst @@ -0,0 +1 @@ +Improve type annotations in several places by removing `Any` usage. diff --git a/newsfragments/3141.bugfix.rst b/newsfragments/3141.bugfix.rst new file mode 100644 index 0000000000..36d378d5a3 --- /dev/null +++ b/newsfragments/3141.bugfix.rst @@ -0,0 +1 @@ +Fix `trio.testing.RaisesGroup`'s typing. diff --git a/newsfragments/3159.misc.rst b/newsfragments/3159.misc.rst new file mode 100644 index 0000000000..9460e11c65 --- /dev/null +++ b/newsfragments/3159.misc.rst @@ -0,0 +1 @@ +Get and enforce 100% coverage diff --git a/notes-to-self/afd-lab.py b/notes-to-self/afd-lab.py deleted file mode 100644 index 600975482c..0000000000 --- a/notes-to-self/afd-lab.py +++ /dev/null @@ -1,182 +0,0 @@ -# A little script to experiment with AFD polling. -# -# This cheats and uses a bunch of internal APIs. Don't follow its example. The -# point is just to experiment with random junk that probably won't work, so we -# can figure out what we actually do want to do internally. - -# Currently this demonstrates what seems to be a weird bug in the Windows -# kernel. If you: -# -# 0. Set up a socket so that it's not writable. -# 1. Submit a SEND poll operation. -# 2. Submit a RECEIVE poll operation. -# 3. Send some data through the socket, to trigger the RECEIVE. -# -# ...then the SEND poll operation completes with the RECEIVE flag set. -# -# (This bug is why our Windows backend jumps through hoops to avoid ever -# issuing multiple polls simultaneously for the same socket.) -# -# This script's output on my machine: -# -# -- Iteration start -- -# Starting a poll for -# Starting a poll for -# Sending another byte -# Poll for : got -# Poll for : Cancelled() -# -- Iteration start -- -# Starting a poll for -# Starting a poll for -# Poll for : got Sending another byte -# Poll for : got -# -# So what we're seeing is: -# -# On the first iteration, where there's initially no data in the socket, the -# SEND completes with the RECEIVE flag set, and the RECEIVE operation doesn't -# return at all, until we cancel it. -# -# On the second iteration, there's already data sitting in the socket from the -# last loop. This time, the RECEIVE returns immediately with the RECEIVE flag -# set, which makes sense -- when starting a RECEIVE poll, it does an immediate -# check to see if there's data already, and if so it does an early exit. But -# the bizarre thing is, when we then send *another* byte of data, the SEND -# operation wakes up with the RECEIVE flag set. -# -# Why is this bizarre? Let me count the ways: -# -# - The SEND operation should never return RECEIVE. -# -# - If it does insist on returning RECEIVE, it should do it immediately, since -# there is already data to receive. But it doesn't. -# -# - And then when we send data into a socket that already has data in it, that -# shouldn't have any effect at all! But instead it wakes up the SEND. -# -# - Also, the RECEIVE call did an early check for data and exited out -# immediately, without going through the whole "register a callback to -# be notified when data arrives" dance. So even if you do have some bug -# in tracking which operations should be woken on which state transitions, -# there's no reason this operation would even touch that tracking data. Yet, -# if we take out the brief RECEIVE, then the SEND *doesn't* wake up. -# -# - Also, if I move the send() call up above the loop, so that there's already -# data in the socket when we start our first iteration, then you would think -# that would just make the first iteration act like it was the second -# iteration. But it doesn't. Instead it makes all the weird behavior -# disappear entirely. -# -# "What do we know โ€ฆ of the world and the universe about us? Our means of -# receiving impressions are absurdly few, and our notions of surrounding -# objects infinitely narrow. We see things only as we are constructed to see -# them, and can gain no idea of their absolute nature. With five feeble senses -# we pretend to comprehend the boundlessly complex cosmos, yet other beings -# with wider, stronger, or different range of senses might not only see very -# differently the things we see, but might see and study whole worlds of -# matter, energy, and life which lie close at hand yet can never be detected -# with the senses we have." - -import os.path -import sys - -sys.path.insert(0, os.path.abspath(os.path.dirname(__file__) + r"\..")) - -import trio - -print(trio.__file__) -import socket - -import trio.testing -from trio._core._io_windows import _afd_helper_handle, _check, _get_base_socket -from trio._core._windows_cffi import ( - AFDPollFlags, - ErrorCodes, - IoControlCodes, - ffi, - kernel32, -) - - -class AFDLab: - def __init__(self): - self._afd = _afd_helper_handle() - trio.lowlevel.register_with_iocp(self._afd) - - async def afd_poll(self, sock, flags, *, exclusive=0): - print(f"Starting a poll for {flags!r}") - lpOverlapped = ffi.new("LPOVERLAPPED") - poll_info = ffi.new("AFD_POLL_INFO *") - poll_info.Timeout = 2**63 - 1 # INT64_MAX - poll_info.NumberOfHandles = 1 - poll_info.Exclusive = exclusive - poll_info.Handles[0].Handle = _get_base_socket(sock) - poll_info.Handles[0].Status = 0 - poll_info.Handles[0].Events = flags - - try: - _check( - kernel32.DeviceIoControl( - self._afd, - IoControlCodes.IOCTL_AFD_POLL, - poll_info, - ffi.sizeof("AFD_POLL_INFO"), - poll_info, - ffi.sizeof("AFD_POLL_INFO"), - ffi.NULL, - lpOverlapped, - ) - ) - except OSError as exc: - if exc.winerror != ErrorCodes.ERROR_IO_PENDING: # pragma: no cover - raise - - try: - await trio.lowlevel.wait_overlapped(self._afd, lpOverlapped) - except: - print(f"Poll for {flags!r}: {sys.exc_info()[1]!r}") - raise - out_flags = AFDPollFlags(poll_info.Handles[0].Events) - print(f"Poll for {flags!r}: got {out_flags!r}") - return out_flags - - -def fill_socket(sock): - try: - while True: - sock.send(b"x" * 65536) - except BlockingIOError: - pass - - -async def main(): - afdlab = AFDLab() - - a, b = socket.socketpair() - a.setblocking(False) - b.setblocking(False) - - fill_socket(a) - - while True: - print("-- Iteration start --") - async with trio.open_nursery() as nursery: - nursery.start_soon( - afdlab.afd_poll, - a, - AFDPollFlags.AFD_POLL_SEND, - ) - await trio.sleep(2) - nursery.start_soon( - afdlab.afd_poll, - a, - AFDPollFlags.AFD_POLL_RECEIVE, - ) - await trio.sleep(2) - print("Sending another byte") - b.send(b"x") - await trio.sleep(2) - nursery.cancel_scope.cancel() - - -trio.run(main) diff --git a/notes-to-self/aio-guest-test.py b/notes-to-self/aio-guest-test.py deleted file mode 100644 index 3c607d0281..0000000000 --- a/notes-to-self/aio-guest-test.py +++ /dev/null @@ -1,53 +0,0 @@ -import asyncio - -import trio - - -async def aio_main(): - loop = asyncio.get_running_loop() - - trio_done_fut = loop.create_future() - - def trio_done_callback(main_outcome): - print(f"trio_main finished: {main_outcome!r}") - trio_done_fut.set_result(main_outcome) - - trio.lowlevel.start_guest_run( - trio_main, - run_sync_soon_threadsafe=loop.call_soon_threadsafe, - done_callback=trio_done_callback, - ) - - (await trio_done_fut).unwrap() - - -async def trio_main(): - print("trio_main!") - - to_trio, from_aio = trio.open_memory_channel(float("inf")) - from_trio = asyncio.Queue() - - _task_ref = asyncio.create_task(aio_pingpong(from_trio, to_trio)) - - from_trio.put_nowait(0) - - async for n in from_aio: - print(f"trio got: {n}") - await trio.sleep(1) - from_trio.put_nowait(n + 1) - if n >= 10: - return - del _task_ref - - -async def aio_pingpong(from_trio, to_trio): - print("aio_pingpong!") - - while True: - n = await from_trio.get() - print(f"aio got: {n}") - await asyncio.sleep(1) - to_trio.send_nowait(n + 1) - - -asyncio.run(aio_main()) diff --git a/notes-to-self/atomic-local.py b/notes-to-self/atomic-local.py deleted file mode 100644 index 643bc16c6a..0000000000 --- a/notes-to-self/atomic-local.py +++ /dev/null @@ -1,35 +0,0 @@ -from types import CodeType - -# Has to be a string :-( -sentinel = "_unique_name" - - -def f(): - print(locals()) - - -# code(argcount, kwonlyargcount, nlocals, stacksize, flags, codestring, -# constants, names, varnames, filename, name, firstlineno, -# lnotab[, freevars[, cellvars]]) -new_code = CodeType( - f.__code__.co_argcount, - f.__code__.co_kwonlyargcount + 1, - f.__code__.co_nlocals + 1, - f.__code__.co_stacksize, - f.__code__.co_flags, - f.__code__.co_code, - f.__code__.co_consts, - f.__code__.co_names, - (*f.__code__.co_varnames, sentinel), - f.__code__.co_filename, - f.__code__.co_name, - f.__code__.co_firstlineno, - f.__code__.co_lnotab, - f.__code__.co_freevars, - f.__code__.co_cellvars, -) - -f.__code__ = new_code -f.__kwdefaults__ = {sentinel: "fdsa"} - -f() diff --git a/notes-to-self/blocking-read-hack.py b/notes-to-self/blocking-read-hack.py deleted file mode 100644 index 56bcd03df9..0000000000 --- a/notes-to-self/blocking-read-hack.py +++ /dev/null @@ -1,51 +0,0 @@ -import errno -import os -import socket - -import trio - -bad_socket = socket.socket() - - -class BlockingReadTimeoutError(Exception): - pass - - -async def blocking_read_with_timeout( - fd, count, timeout # noqa: ASYNC109 # manual timeout -): - print("reading from fd", fd) - cancel_requested = False - - async def kill_it_after_timeout(new_fd): - print("sleeping") - await trio.sleep(timeout) - print("breaking the fd") - os.dup2(bad_socket.fileno(), new_fd, inheritable=False) - # MAGIC - print("setuid(getuid())") - os.setuid(os.getuid()) - nonlocal cancel_requested - cancel_requested = True - - new_fd = os.dup(fd) - print("working fd is", new_fd) - try: - async with trio.open_nursery() as nursery: - nursery.start_soon(kill_it_after_timeout, new_fd) - try: - data = await trio.to_thread.run_sync(os.read, new_fd, count) - except OSError as exc: - if cancel_requested and exc.errno == errno.ENOTCONN: - # Call was successfully cancelled. In a real version we'd - # integrate properly with Trio's cancellation tools; here - # we'll just raise an arbitrary error. - raise BlockingReadTimeoutError from None - print("got", data) - nursery.cancel_scope.cancel() - return data - finally: - os.close(new_fd) - - -trio.run(blocking_read_with_timeout, 0, 10, 2) diff --git a/notes-to-self/estimate-task-size.py b/notes-to-self/estimate-task-size.py deleted file mode 100644 index 0010c7a2b4..0000000000 --- a/notes-to-self/estimate-task-size.py +++ /dev/null @@ -1,33 +0,0 @@ -# Little script to get a rough estimate of how much memory each task takes - -import resource - -import trio -import trio.testing - -LOW = 1000 -HIGH = 10000 - - -async def tinytask(): - await trio.sleep_forever() - - -async def measure(count): - async with trio.open_nursery() as nursery: - for _ in range(count): - nursery.start_soon(tinytask) - await trio.testing.wait_all_tasks_blocked() - nursery.cancel_scope.cancel() - return resource.getrusage(resource.RUSAGE_SELF) - - -async def main(): - low_usage = await measure(LOW) - high_usage = await measure(HIGH + LOW) - - print("Memory usage per task:", (high_usage.ru_maxrss - low_usage.ru_maxrss) / HIGH) - print("(kilobytes on Linux, bytes on macOS)") - - -trio.run(main) diff --git a/notes-to-self/fbsd-pipe-close-notify.py b/notes-to-self/fbsd-pipe-close-notify.py deleted file mode 100644 index ef60d6900e..0000000000 --- a/notes-to-self/fbsd-pipe-close-notify.py +++ /dev/null @@ -1,37 +0,0 @@ -# This script completes correctly on macOS and FreeBSD 13.0-CURRENT, but hangs -# on FreeBSD 12.1. I'm told the fix will be backported to 12.2 (which is due -# out in October 2020). -# -# Upstream bug: https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350 - -import os -import select - -r, w = os.pipe() - -os.set_blocking(w, False) - -print("filling pipe buffer") -try: - while True: - os.write(w, b"x") -except BlockingIOError: - pass - -_, wfds, _ = select.select([], [w], [], 0) -print("select() says the write pipe is", "writable" if w in wfds else "NOT writable") - -kq = select.kqueue() -event = select.kevent(w, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) -kq.control([event], 0) - -print("closing read end of pipe") -os.close(r) - -_, wfds, _ = select.select([], [w], [], 0) -print("select() says the write pipe is", "writable" if w in wfds else "NOT writable") - -print("waiting for kqueue to report the write end is writable") -got = kq.control([], 1) -print("done!") -print(got) diff --git a/notes-to-self/file-read-latency.py b/notes-to-self/file-read-latency.py deleted file mode 100644 index e82b027a03..0000000000 --- a/notes-to-self/file-read-latency.py +++ /dev/null @@ -1,27 +0,0 @@ -import time - -# https://bitbucket.org/pypy/pypy/issues/2624/weird-performance-on-pypy3-when-reading -# COUNT = 100000 -# f = open("/etc/passwd", "rt") -COUNT = 1000000 -# With default buffering this test never even syscalls, and goes at about ~140 -# ns per call, instead of ~500 ns/call for the syscall and related overhead. -# That's probably more fair -- the BufferedIOBase code can't service random -# accesses, even if your working set fits entirely in RAM. -with open("/etc/passwd", "rb") as f: # , buffering=0) - while True: - start = time.perf_counter() - for _ in range(COUNT): - f.seek(0) - f.read(1) - between = time.perf_counter() - for _ in range(COUNT): - f.seek(0) - end = time.perf_counter() - - both = (between - start) / COUNT * 1e9 - seek = (end - between) / COUNT * 1e9 - read = both - seek - print( - f"{both:.2f} ns/(seek+read), {seek:.2f} ns/seek, estimate ~{read:.2f} ns/read" - ) diff --git a/notes-to-self/graceful-shutdown-idea.py b/notes-to-self/graceful-shutdown-idea.py deleted file mode 100644 index 9497af9724..0000000000 --- a/notes-to-self/graceful-shutdown-idea.py +++ /dev/null @@ -1,66 +0,0 @@ -import signal - -import gsm -import trio - - -class GracefulShutdownManager: - def __init__(self): - self._shutting_down = False - self._cancel_scopes = set() - - def start_shutdown(self): - self._shutting_down = True - for cancel_scope in self._cancel_scopes: - cancel_scope.cancel() - - def cancel_on_graceful_shutdown(self): - cancel_scope = trio.CancelScope() - self._cancel_scopes.add(cancel_scope) - if self._shutting_down: - cancel_scope.cancel() - return cancel_scope - - @property - def shutting_down(self): - return self._shutting_down - - -# Code can check gsm.shutting_down occasionally at appropriate points to see -# if it should exit. -# -# When doing operations that might block for an indefinite about of time and -# that should be aborted when a graceful shutdown starts, wrap them in 'with -# gsm.cancel_on_graceful_shutdown()'. -async def stream_handler(stream): - while True: - with gsm.cancel_on_graceful_shutdown(): - data = await stream.receive_some() - print(f"{data = }") - if gsm.shutting_down: - break - - -# To trigger the shutdown: -async def listen_for_shutdown_signals(): - with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as signal_aiter: - async for _sig in signal_aiter: - gsm.start_shutdown() - break - # TODO: it'd be nice to have some logic like "if we get another - # signal, or if 30 seconds pass, then do a hard shutdown". - # That's easy enough: - # - # with trio.move_on_after(30): - # async for sig in signal_aiter: - # break - # sys.exit() - # - # The trick is, if we do finish shutting down in (say) 10 seconds, - # then we want to exit immediately. So I guess you'd need the main - # part of the program to detect when it's finished shutting down, and - # then cancel listen_for_shutdown_signals? - # - # I guess this would be a good place to use @smurfix's daemon task - # construct: - # https://github.com/python-trio/trio/issues/569#issuecomment-408419260 diff --git a/notes-to-self/how-does-windows-so-reuseaddr-work.py b/notes-to-self/how-does-windows-so-reuseaddr-work.py deleted file mode 100644 index 70dd75e39f..0000000000 --- a/notes-to-self/how-does-windows-so-reuseaddr-work.py +++ /dev/null @@ -1,76 +0,0 @@ -# There are some tables here: -# https://web.archive.org/web/20120206195747/https://msdn.microsoft.com/en-us/library/windows/desktop/ms740621(v=vs.85).aspx -# They appear to be wrong. -# -# See https://github.com/python-trio/trio/issues/928 for details and context - -import errno -import socket - -modes = ["default", "SO_REUSEADDR", "SO_EXCLUSIVEADDRUSE"] -bind_types = ["wildcard", "specific"] - - -def sock(mode): - s = socket.socket(family=socket.AF_INET) - if mode == "SO_REUSEADDR": - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - elif mode == "SO_EXCLUSIVEADDRUSE": - s.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) - return s - - -def bind(sock, bind_type): - if bind_type == "wildcard": - sock.bind(("0.0.0.0", 12345)) - elif bind_type == "specific": - sock.bind(("127.0.0.1", 12345)) - else: - raise AssertionError() - - -def table_entry(mode1, bind_type1, mode2, bind_type2): - with sock(mode1) as sock1: - bind(sock1, bind_type1) - try: - with sock(mode2) as sock2: - bind(sock2, bind_type2) - except OSError as exc: - if exc.winerror == errno.WSAEADDRINUSE: - return "INUSE" - elif exc.winerror == errno.WSAEACCES: - return "ACCESS" - raise - else: - return "Success" - - -print( - """ - second bind - | """ - + " | ".join(["%-19s" % mode for mode in modes]) -) - -print(""" """, end="") -for _ in modes: - print(" | " + " | ".join(["%8s" % bind_type for bind_type in bind_types]), end="") - -print( - """ -first bind -----------------------------------------------------------------""" - # default | wildcard | INUSE | Success | ACCESS | Success | INUSE | Success -) - -for mode1 in modes: - for bind_type1 in bind_types: - row = [] - for mode2 in modes: - for bind_type2 in bind_types: - entry = table_entry(mode1, bind_type1, mode2, bind_type2) - row.append(entry) - # print(mode1, bind_type1, mode2, bind_type2, entry) - print( - f"{mode1:>19} | {bind_type1:>8} | " - + " | ".join(["%8s" % entry for entry in row]) - ) diff --git a/notes-to-self/loopy.py b/notes-to-self/loopy.py deleted file mode 100644 index 99f6e050b9..0000000000 --- a/notes-to-self/loopy.py +++ /dev/null @@ -1,23 +0,0 @@ -import time - -import trio - - -async def loopy(): - try: - while True: - # synchronous sleep to avoid maxing out CPU - time.sleep(0.01) # noqa: ASYNC251 - await trio.lowlevel.checkpoint() - except KeyboardInterrupt: - print("KI!") - - -async def main(): - async with trio.open_nursery() as nursery: - nursery.start_soon(loopy) - nursery.start_soon(loopy) - nursery.start_soon(loopy) - - -trio.run(main) diff --git a/notes-to-self/lots-of-tasks.py b/notes-to-self/lots-of-tasks.py deleted file mode 100644 index 048c69a7ec..0000000000 --- a/notes-to-self/lots-of-tasks.py +++ /dev/null @@ -1,15 +0,0 @@ -import sys - -import trio - -(COUNT_STR,) = sys.argv[1:] -COUNT = int(COUNT_STR) - - -async def main(): - async with trio.open_nursery() as nursery: - for _ in range(COUNT): - nursery.start_soon(trio.sleep, 1) - - -trio.run(main) diff --git a/notes-to-self/manual-signal-handler.py b/notes-to-self/manual-signal-handler.py deleted file mode 100644 index d865fa89ee..0000000000 --- a/notes-to-self/manual-signal-handler.py +++ /dev/null @@ -1,24 +0,0 @@ -# How to manually call the SIGINT handler on Windows without using raise() or -# similar. -import os -import sys - -if os.name == "nt": - import cffi - - ffi = cffi.FFI() - ffi.cdef( - """ - void* WINAPI GetProcAddress(void* hModule, char* lpProcName); - typedef void (*PyOS_sighandler_t)(int); - """ - ) - kernel32 = ffi.dlopen("kernel32.dll") - PyOS_getsig_ptr = kernel32.GetProcAddress( - ffi.cast("void*", sys.dllhandle), b"PyOS_getsig" - ) - PyOS_getsig = ffi.cast("PyOS_sighandler_t (*)(int)", PyOS_getsig_ptr) - - import signal - - PyOS_getsig(signal.SIGINT)(signal.SIGINT) diff --git a/notes-to-self/measure-listen-backlog.py b/notes-to-self/measure-listen-backlog.py deleted file mode 100644 index b7253b86cc..0000000000 --- a/notes-to-self/measure-listen-backlog.py +++ /dev/null @@ -1,28 +0,0 @@ -import trio - - -async def run_test(nominal_backlog): - print("--\nnominal:", nominal_backlog) - - listen_sock = trio.socket.socket() - await listen_sock.bind(("127.0.0.1", 0)) - listen_sock.listen(nominal_backlog) - client_socks = [] - while True: - client_sock = trio.socket.socket() - # Generally the response to the listen buffer being full is that the - # SYN gets dropped, and the client retries after 1 second. So we - # assume that any connect() call to localhost that takes >0.5 seconds - # indicates a dropped SYN. - with trio.move_on_after(0.5) as cancel_scope: - await client_sock.connect(listen_sock.getsockname()) - if cancel_scope.cancelled_caught: - break - client_socks.append(client_sock) - print("actual:", len(client_socks)) - for client_sock in client_socks: - client_sock.close() - - -for nominal_backlog in [10, trio.socket.SOMAXCONN, 65535]: - trio.run(run_test, nominal_backlog) diff --git a/notes-to-self/ntp-example.py b/notes-to-self/ntp-example.py deleted file mode 100644 index 2bb9f80fb3..0000000000 --- a/notes-to-self/ntp-example.py +++ /dev/null @@ -1,96 +0,0 @@ -# If you want to use IPv6, then: -# - replace AF_INET with AF_INET6 everywhere -# - use the hostname "2.pool.ntp.org" -# (see: https://news.ntppool.org/2011/06/continuing-ipv6-deployment/) - -import datetime -import struct - -import trio - - -def make_query_packet(): - """Construct a UDP packet suitable for querying an NTP server to ask for - the current time.""" - - # The structure of an NTP packet is described here: - # https://tools.ietf.org/html/rfc5905#page-19 - # They're always 48 bytes long, unless you're using extensions, which we - # aren't. - packet = bytearray(48) - - # The first byte contains 3 subfields: - # first 2 bits: 11, leap second status unknown - # next 3 bits: 100, NTP version indicator, 0b100 == 4 = version 4 - # last 3 bits: 011, NTP mode indicator, 0b011 == 3 == "client" - packet[0] = 0b11100011 - - # For an outgoing request, all other fields can be left as zeros. - - return packet - - -def extract_transmit_timestamp(ntp_packet): - """Given an NTP packet, extract the "transmit timestamp" field, as a - Python datetime.""" - - # The transmit timestamp is the time that the server sent its response. - # It's stored in bytes 40-47 of the NTP packet. See: - # https://tools.ietf.org/html/rfc5905#page-19 - encoded_transmit_timestamp = ntp_packet[40:48] - - # The timestamp is stored in the "NTP timestamp format", which is a 32 - # byte count of whole seconds, followed by a 32 byte count of fractions of - # a second. See: - # https://tools.ietf.org/html/rfc5905#page-13 - seconds, fraction = struct.unpack("!II", encoded_transmit_timestamp) - - # The timestamp is the number of seconds since January 1, 1900 (ignoring - # leap seconds). To convert it to a datetime object, we do some simple - # datetime arithmetic: - base_time = datetime.datetime(1900, 1, 1) - offset = datetime.timedelta(seconds=seconds + fraction / 2**32) - return base_time + offset - - -async def main(): - print("Our clock currently reads (in UTC):", datetime.datetime.utcnow()) - - # Look up some random NTP servers. - # (See www.pool.ntp.org for information about the NTP pool.) - servers = await trio.socket.getaddrinfo( - "pool.ntp.org", # host - "ntp", # port - family=trio.socket.AF_INET, # IPv4 - type=trio.socket.SOCK_DGRAM, # UDP - ) - - # Construct an NTP query packet. - query_packet = make_query_packet() - - # Create a UDP socket - udp_sock = trio.socket.socket( - family=trio.socket.AF_INET, # IPv4 - type=trio.socket.SOCK_DGRAM, # UDP - ) - - # Use the socket to send the query packet to each of the servers. - print("-- Sending queries --") - for server in servers: - address = server[-1] - print("Sending to:", address) - await udp_sock.sendto(query_packet, address) - - # Read responses from the socket. - print("-- Reading responses (for 10 seconds) --") - with trio.move_on_after(10): - while True: - # We accept packets up to 1024 bytes long (though in practice NTP - # packets will be much shorter). - data, address = await udp_sock.recvfrom(1024) - print("Got response from:", address) - transmit_timestamp = extract_transmit_timestamp(data) - print("Their clock read (in UTC):", transmit_timestamp) - - -trio.run(main) diff --git a/notes-to-self/print-task-tree.py b/notes-to-self/print-task-tree.py deleted file mode 100644 index 54b97ec014..0000000000 --- a/notes-to-self/print-task-tree.py +++ /dev/null @@ -1,113 +0,0 @@ -# NOTE: -# possibly it would be easier to use https://pypi.org/project/tree-format/ -# instead of formatting by hand like this code does... - -""" -Demo/exploration of how to print a task tree. Outputs: - - -โ”œโ”€ __main__.main -โ”‚ โ”œโ”€ __main__.child1 -โ”‚ โ”‚ โ”œโ”€ trio.sleep_forever -โ”‚ โ”‚ โ”œโ”€ __main__.child2 -โ”‚ โ”‚ โ”‚ โ”œโ”€ trio.sleep_forever -โ”‚ โ”‚ โ”‚ โ””โ”€ trio.sleep_forever -โ”‚ โ”‚ โ””โ”€ __main__.child2 -โ”‚ โ”‚ โ”œโ”€ trio.sleep_forever -โ”‚ โ”‚ โ””โ”€ trio.sleep_forever -โ”‚ โ””โ”€ (nested nursery) -โ”‚ โ””โ”€ __main__.child1 -โ”‚ โ”œโ”€ trio.sleep_forever -โ”‚ โ”œโ”€ __main__.child2 -โ”‚ โ”‚ โ”œโ”€ trio.sleep_forever -โ”‚ โ”‚ โ””โ”€ trio.sleep_forever -โ”‚ โ””โ”€ __main__.child2 -โ”‚ โ”œโ”€ trio.sleep_forever -โ”‚ โ””โ”€ trio.sleep_forever -โ””โ”€ - -""" - -import trio -import trio.testing - -MID_PREFIX = "โ”œโ”€ " -MID_CONTINUE = "โ”‚ " -END_PREFIX = "โ””โ”€ " -END_CONTINUE = " " * len(END_PREFIX) - - -def current_root_task(): - task = trio.lowlevel.current_task() - while task.parent_nursery is not None: - task = task.parent_nursery.parent_task - return task - - -def _render_subtree(name, rendered_children): - lines = [] - lines.append(name) - for child_lines in rendered_children: - if child_lines is rendered_children[-1]: - first_prefix = END_PREFIX - rest_prefix = END_CONTINUE - else: - first_prefix = MID_PREFIX - rest_prefix = MID_CONTINUE - lines.append(first_prefix + child_lines[0]) - lines.extend(rest_prefix + child_line for child_line in child_lines[1:]) - return lines - - -def _rendered_nursery_children(nursery): - return [task_tree_lines(t) for t in nursery.child_tasks] - - -def task_tree_lines(task=None): - if task is None: - task = current_root_task() - rendered_children = [] - nurseries = list(task.child_nurseries) - while nurseries: - nursery = nurseries.pop() - nursery_children = _rendered_nursery_children(nursery) - if rendered_children: - nested = _render_subtree("(nested nursery)", rendered_children) - nursery_children.append(nested) - rendered_children = nursery_children - return _render_subtree(task.name, rendered_children) - - -def print_task_tree(task=None): - for line in task_tree_lines(task): - print(line) - - -################################################################ - - -async def child2(): - async with trio.open_nursery() as nursery: - nursery.start_soon(trio.sleep_forever) - nursery.start_soon(trio.sleep_forever) - - -async def child1(): - async with trio.open_nursery() as nursery: - nursery.start_soon(child2) - nursery.start_soon(child2) - nursery.start_soon(trio.sleep_forever) - - -async def main(): - async with trio.open_nursery() as nursery0: - nursery0.start_soon(child1) - async with trio.open_nursery() as nursery1: - nursery1.start_soon(child1) - - await trio.testing.wait_all_tasks_blocked() - print_task_tree() - nursery0.cancel_scope.cancel() - - -trio.run(main) diff --git a/notes-to-self/proxy-benchmarks.py b/notes-to-self/proxy-benchmarks.py deleted file mode 100644 index d28909a347..0000000000 --- a/notes-to-self/proxy-benchmarks.py +++ /dev/null @@ -1,175 +0,0 @@ -import textwrap -import time - -methods = {"fileno"} - - -class Proxy1: - strategy = "__getattr__" - works_for = "any attr" - - def __init__(self, wrapped): - self._wrapped = wrapped - - def __getattr__(self, name): - if name in methods: - return getattr(self._wrapped, name) - raise AttributeError(name) - - -################################################################ - - -class Proxy2: - strategy = "generated methods (getattr + closure)" - works_for = "methods" - - def __init__(self, wrapped): - self._wrapped = wrapped - - -def add_wrapper(cls, method): - def wrapper(self, *args, **kwargs): - return getattr(self._wrapped, method)(*args, **kwargs) - - setattr(cls, method, wrapper) - - -for method in methods: - add_wrapper(Proxy2, method) - -################################################################ - - -class Proxy3: - strategy = "generated methods (exec)" - works_for = "methods" - - def __init__(self, wrapped): - self._wrapped = wrapped - - -def add_wrapper(cls, method): - code = textwrap.dedent( - f""" - def wrapper(self, *args, **kwargs): - return self._wrapped.{method}(*args, **kwargs) - """ - ) - ns = {} - exec(code, ns) - setattr(cls, method, ns["wrapper"]) - - -for method in methods: - add_wrapper(Proxy3, method) - -################################################################ - - -class Proxy4: - strategy = "generated properties (getattr + closure)" - works_for = "any attr" - - def __init__(self, wrapped): - self._wrapped = wrapped - - -def add_wrapper(cls, attr): - def getter(self): - return getattr(self._wrapped, attr) - - def setter(self, newval): - setattr(self._wrapped, attr, newval) - - def deleter(self): - delattr(self._wrapped, attr) - - setattr(cls, attr, property(getter, setter, deleter)) - - -for method in methods: - add_wrapper(Proxy4, method) - -################################################################ - - -class Proxy5: - strategy = "generated properties (exec)" - works_for = "any attr" - - def __init__(self, wrapped): - self._wrapped = wrapped - - -def add_wrapper(cls, attr): - code = textwrap.dedent( - f""" - def getter(self): - return self._wrapped.{attr} - - def setter(self, newval): - self._wrapped.{attr} = newval - - def deleter(self): - del self._wrapped.{attr} - """ - ) - ns = {} - exec(code, ns) - setattr(cls, attr, property(ns["getter"], ns["setter"], ns["deleter"])) - - -for method in methods: - add_wrapper(Proxy5, method) - -################################################################ - - -# methods only -class Proxy6: - strategy = "copy attrs from wrappee to wrapper" - works_for = "methods + constant attrs" - - def __init__(self, wrapper): - self._wrapper = wrapper - - for method in methods: - setattr(self, method, getattr(self._wrapper, method)) - - -################################################################ - -classes = [Proxy1, Proxy2, Proxy3, Proxy4, Proxy5, Proxy6] - - -def check(cls): - with open("/etc/passwd") as f: - p = cls(f) - assert p.fileno() == f.fileno() - - -for cls in classes: - check(cls) - -with open("/etc/passwd") as f: - objs = [c(f) for c in classes] - - COUNT = 1000000 - try: - import __pypy__ # noqa: F401 # __pypy__ imported but unused - except ImportError: - pass - else: - COUNT *= 10 - - while True: - print("-------") - for obj in objs: - start = time.perf_counter() - for _ in range(COUNT): - obj.fileno() - # obj.fileno - end = time.perf_counter() - per_usec = COUNT / (end - start) / 1e6 - print(f"{per_usec:7.2f} / us: {obj.strategy} ({obj.works_for})") diff --git a/notes-to-self/reopen-pipe.py b/notes-to-self/reopen-pipe.py deleted file mode 100644 index dbccd567d7..0000000000 --- a/notes-to-self/reopen-pipe.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -import tempfile -import threading -import time - - -def check_reopen(r1, w): - try: - print("Reopening read end") - r2 = os.open(f"/proc/self/fd/{r1}", os.O_RDONLY) - - print(f"r1 is {r1}, r2 is {r2}") - - print("checking they both can receive from w...") - - os.write(w, b"a") - assert os.read(r1, 1) == b"a" - - os.write(w, b"b") - assert os.read(r2, 1) == b"b" - - print("...ok") - - print("setting r2 to non-blocking") - os.set_blocking(r2, False) - - print("os.get_blocking(r1) ==", os.get_blocking(r1)) - print("os.get_blocking(r2) ==", os.get_blocking(r2)) - - # Check r2 is really truly non-blocking - try: - os.read(r2, 1) - except BlockingIOError: - print("r2 definitely seems to be in non-blocking mode") - - # Check that r1 is really truly still in blocking mode - def sleep_then_write(): - time.sleep(1) - os.write(w, b"c") - - threading.Thread(target=sleep_then_write, daemon=True).start() - assert os.read(r1, 1) == b"c" - print("r1 definitely seems to be in blocking mode") - except Exception as exc: - print(f"ERROR: {exc!r}") - - -print("-- testing anonymous pipe --") -check_reopen(*os.pipe()) - -print("-- testing FIFO --") -with tempfile.TemporaryDirectory() as tmpdir: - fifo = tmpdir + "/" + "myfifo" - os.mkfifo(fifo) - # "A process can open a FIFO in nonblocking mode. In this case, opening - # for read-only will succeed even if no-one has opened on the write side - # yet and opening for write-only will fail with ENXIO (no such device or - # address) unless the other end has already been opened." -- Linux fifo(7) - r = os.open(fifo, os.O_RDONLY | os.O_NONBLOCK) - assert not os.get_blocking(r) - os.set_blocking(r, True) - assert os.get_blocking(r) - w = os.open(fifo, os.O_WRONLY) - check_reopen(r, w) - -print("-- testing socketpair --") -import socket - -rs, ws = socket.socketpair() -check_reopen(rs.fileno(), ws.fileno()) diff --git a/notes-to-self/schedule-timing.py b/notes-to-self/schedule-timing.py deleted file mode 100644 index 11594b7cc7..0000000000 --- a/notes-to-self/schedule-timing.py +++ /dev/null @@ -1,42 +0,0 @@ -import time - -import trio - -LOOPS = 0 -RUNNING = True - - -async def reschedule_loop(depth): - if depth == 0: - global LOOPS - while RUNNING: - LOOPS += 1 - await trio.lowlevel.checkpoint() - # await trio.lowlevel.cancel_shielded_checkpoint() - else: - await reschedule_loop(depth - 1) - - -async def report_loop(): - global RUNNING - try: - while True: - start_count = LOOPS - start_time = time.perf_counter() - await trio.sleep(1) - end_time = time.perf_counter() - end_count = LOOPS - loops = end_count - start_count - duration = end_time - start_time - print(f"{loops / duration} loops/sec") - finally: - RUNNING = False - - -async def main(): - async with trio.open_nursery() as nursery: - nursery.start_soon(reschedule_loop, 10) - nursery.start_soon(report_loop) - - -trio.run(main) diff --git a/notes-to-self/server.crt b/notes-to-self/server.crt deleted file mode 100644 index 9c58d8e65b..0000000000 --- a/notes-to-self/server.crt +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDBjCCAe4CCQDq+3W9D8C4ejANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB -VTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMB4XDTE3MDMxOTAzMDk1MVoXDTE4MDMxOTAzMDk1MVowRTELMAkG -A1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0 -IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB -AOwDDFVh8pvIrhZtIIX6pb3/PO5SM3rWsfoyyHi73GxemIiEHfYEjMKN8Eo10jUv -4G0n8VlrrmuhGR+UuHY6jCxjoCuYWszQhwBZBaeGE24ydtO/RE24yhNsJHPQWXMe -TL4mg1EBjJYXTwNhd7SwgCpkBQ+724ZJg+CmiPuYhVLdvjjUUmwiSbeueyULIPEJ -G1EWkKdU5pYtyyTZoc0x2YEjes3YNWY563yk+RljvidFBMyAX8N3fF4yrCCHDeY6 -UPdpXry/BJcEJm7PY2lMhbL71T6499qKnmSaWyJjm+KqbXSEYXoWDVBBvg5pR9Ia -XSoJ1MTfJ8eYnZDs5mETYDkCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEApaW8WiKA -3yDOUVzgwkeX3HxvxfhxtMTPBmO1M8YgX1yi+URkkKakc6bg3XW1saQrxBkXWwBr -81Atd0tOLwHsC1HPd7Y5Q/1LKZiYFq2Sva6eZfeedRF/0f/SQC+rSvZNI5DIVPS4 -jW/EpyMKIeerIyWeFXz0/NWcYLCDWN6m2iDtR3m98bJcqSdUemLgyR13EAWsaVZ7 -dB6nkwGl9e78SOIHeGYg1Fb0B7IN2Tqw2tO3Xn0mzhvqs65OYuYo4pB0FzxiySAB -q2nrgu6kGhkQw/RQ8QJ5MYjydYqCU0I4Qve1W7RoUxRnJvxJrMuvcdlMeboASKNl -L7YQurFGvAAiZQ== ------END CERTIFICATE----- diff --git a/notes-to-self/server.csr b/notes-to-self/server.csr deleted file mode 100644 index f0fbc3829d..0000000000 --- a/notes-to-self/server.csr +++ /dev/null @@ -1,16 +0,0 @@ ------BEGIN CERTIFICATE REQUEST----- -MIICijCCAXICAQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx -ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN -AQEBBQADggEPADCCAQoCggEBAOwDDFVh8pvIrhZtIIX6pb3/PO5SM3rWsfoyyHi7 -3GxemIiEHfYEjMKN8Eo10jUv4G0n8VlrrmuhGR+UuHY6jCxjoCuYWszQhwBZBaeG -E24ydtO/RE24yhNsJHPQWXMeTL4mg1EBjJYXTwNhd7SwgCpkBQ+724ZJg+CmiPuY -hVLdvjjUUmwiSbeueyULIPEJG1EWkKdU5pYtyyTZoc0x2YEjes3YNWY563yk+Rlj -vidFBMyAX8N3fF4yrCCHDeY6UPdpXry/BJcEJm7PY2lMhbL71T6499qKnmSaWyJj -m+KqbXSEYXoWDVBBvg5pR9IaXSoJ1MTfJ8eYnZDs5mETYDkCAwEAAaAAMA0GCSqG -SIb3DQEBCwUAA4IBAQC+LhkPmCjxk5Nzn743u+7D/YzNhjv8Xv4aGUjjNyspVNso -tlCAWkW2dWo8USvQrMUz5yl6qj6QQlg0QaYfaIiK8pkGz4s+Sh1plz1Eaa7QDK4O -0wmtP6KkJyQW561ZY8sixS1DevKOmsp2Pa9fWU/vqKfzRv85A975XNadp6hkxXd7 -YOZCrSZjTnakpQvKoItvT9Xk7yKP6BI6h/03XORscbW/HyvLGoVLdE80yIkmjSot -3JXxHspT27bWNWhz/Slph3UFaVyOVGXFTAqkLDZ3OISMnuC+q/t38EHYkR1aev/l -4WogCtlWkFZ3bmhmlhJrH/bdTEkM6WopwoC6bczh ------END CERTIFICATE REQUEST----- diff --git a/notes-to-self/server.key b/notes-to-self/server.key deleted file mode 100644 index c0ba0b8582..0000000000 --- a/notes-to-self/server.key +++ /dev/null @@ -1,27 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIIEogIBAAKCAQEA7AMMVWHym8iuFm0ghfqlvf887lIzetax+jLIeLvcbF6YiIQd -9gSMwo3wSjXSNS/gbSfxWWuua6EZH5S4djqMLGOgK5hazNCHAFkFp4YTbjJ2079E -TbjKE2wkc9BZcx5MviaDUQGMlhdPA2F3tLCAKmQFD7vbhkmD4KaI+5iFUt2+ONRS -bCJJt657JQsg8QkbURaQp1Tmli3LJNmhzTHZgSN6zdg1ZjnrfKT5GWO+J0UEzIBf -w3d8XjKsIIcN5jpQ92levL8ElwQmbs9jaUyFsvvVPrj32oqeZJpbImOb4qptdIRh -ehYNUEG+DmlH0hpdKgnUxN8nx5idkOzmYRNgOQIDAQABAoIBABuus9ij93fsTvcU -b7cnUh95+6ScgatL2W5WXItExbL0WYHRtU3w9K2xRlj9/Rz98536DGYHqlq3d6Hr -qMM9VMm0GcpjQWs6nksdJfujT04inytxCMrw/MrQaWooKwXErQ20qLxsqRfFvh/Q -Y+EOvsm6F5nj1/jlUJGeFv0jw6eXXxH6bqVUVYIaVCpAMB5Sm8caQ4dAI9UESZJv -vuucT24iSyV8vp060L1tNKgRUr5e2CMfbucauZh0nLALPAyu1I07Ce62q9wLLw66 -c2FLHcZBkTGvL0bPe89ttJJuK0jttHV6GQ/OneytezZFxLw1DMsG3VxzbXt2X7AN -noGzrDECgYEA/fnK0xlNir9bTNmOID42GoVUYF6iYWUf//rRlCKsRofSlPjTZDJK -grl/plTBKDE6qDDEkB1mrEkJufqP3slyq66NfkP0NLoo+PFkGSnsbvUvFNYwcYvH -7w2NWo/GvM4DJRqHvrETryBQwQtBJFsq9biWd3+hNCXYrhawKGqbzw0CgYEA7eSa -T6zIdmvszG5x1XzQ3k29GwUA4SLL1YfV2pnLxoMZgW5Q6T+cOCpJdEuL0BXCNelP -gk0gNXNvCzylwVC0BbpefFaJYsWK6gVg1EwDkiZcGx4FnKd0TWYer6RWrZ9cVohT -eNwix9kKVef7chf+2006eE1O8D0UYwZMpGifqt0CgYAKjmtjwtV6QuHkm9ZQeMV+ -7LPJHaXaLn3aAe7cHWTTuamDD6SZsY1vSY6It1Uf+ovZmc1RwCcYWiDRXhzEwdLG -WAcBjImF94bkcgQbF6cAJajDUPPKhGjXAtUxQnCcQGPZEvU5c9rBmLJCk9ktTazH -cdivNtrYdApBkifYRjYbsQKBgDZl0ctqTSSXJTzR/IG+2twalqV5DWxt0oJvXz1v -caNhExH/sczEWOqW8NkA9WWNtC0zvpSjIjxWuwuswJJl6+Rra3OvLhdB6LP+qteg -0ig3UVR6FvptaDDSqy2qvI9TI4A+CChY3jMotC5Ur7C1P/fRvw8HToesz96c8CWg -LvKZAoGAS4VW2YaYVlN14mkKGtHbZq0XWZGursAwYJn6H5jBCjl+89tsFzfCnTER -hZFH+zs281v1bVbCFws8lWrZg8AZh55dG8CcLtuCkTyXJ/aAdlan0+FmXV60+RLP -Z1TyykQG/oDgO1Z+5GrcN2b0FOFaSbH2NRzRlhyOI63yTQi4lT8= ------END RSA PRIVATE KEY----- diff --git a/notes-to-self/server.orig.key b/notes-to-self/server.orig.key deleted file mode 100644 index 28fac173ff..0000000000 --- a/notes-to-self/server.orig.key +++ /dev/null @@ -1,30 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -Proc-Type: 4,ENCRYPTED -DEK-Info: DES-EDE3-CBC,D9C5B2214855387C - -gYuCZiXsU74IOErbOGOmc1y/BFP1N7UuRO19tidUrq1O6sreJSAVKRibIynAwmXj -p5xvaAnBBIZIH6X7I2vduJgtUeeyvy5yxR98pD6liRKDxFaVD+O1m5IZxSbAs2De -olk4Zlv3YULpbVF6Ud+QuLmgbqfmT+8NVGm4MwRey7Gkj+LEfGrNjpfgLqNRIaUZ -XDPQh9HLZYCsAbz5OeRHwJwawLO74fvWBkFjsQyoWLgJzqZFmt15SyrRufBeKYP6 -oKKemsiW8/A2+i6Rb1vHYOJJ6c9jeeHPJkZSbfNWf4/Z702DAMIisbHmTCQzsUrX -178d2Z7sDKcuDCQ1EInnLRb3YET/V83wGDWyHxWepaHLWHd5S7tFbsqZFsXxIuYM -lcZZVSPsOLnG2SozZK+Tr2RX7jkI4Kmfh0RDgtKBYQQopZjRSFUG2hvvH3EIxVIf -JyUG8AA5RT1J9tkcSJJ5MS40So7i3eyAuZXuYVSkuDai/mu2IUU8vYnwB8a1psU1 -P2CGUj2AFopvMAfSOYIPGHpcIn+lvxuXUdczR/Yikp/BhGT+diJjP68CUsMBdyq7 -pcVmMVyQPVpcsMag3IXGgIAF1v1GhO3zDMd1uXA1lyrHQa6CEah3z+4WFSWwYZ0I -OZz5qM9bnfKoAQesp+xmcZhs8cbrblMRVDkWiPUixxKVJk3eBUsMoa1WYq/2u0ly -EgvNlY39B/3eiLi+k+S6gVGT8a4AP6n4RuxPD4g0A79bI1LpC3xU6vcKV/GyIP69 -t2DHR2q9UDEiRj9DxjuTzxew7eRX8ktD7DhYV06BxrgQIRRiL9MrZRKGuqzXcMP/ -kWY71ioFZJ1ViZkpy7bEsYrpF2XBjGge3We0s2udnrY3r3ogxOjtZiT27e2zEbXD -T59C3gecuEzCSCZ3eQtdcVC9m3RdHMTNNKvqmTVFPgfGOoM5u2gG+rYjhetbpTDB -T5drcEEAcV11DHuokU4tlqOdIWdLuBsK3xgO98JasEr1LyYJT1fnjB+6AbhjfSS2 -p5TPekmSwaZbaBzwfP1xmhINJm388GCROXMkc9iLAWN9npHhssfMAA2WMXqDTgSt -34oUnHgLGmvOm5HzJE/tTR1WP1Rye4nKNLwsbk2x7WXxqcNUPYc+OVmZbsl/R5Gz -3zRHPts01mT/eaSfqj1wkJgpYtDQLPO+V1fc2pDgJmQMYyr7OCLI6I9GJBlB8gVq -aemv0TMi3/eUVyJRaAHxAAi7YMsrSkuKUrsbLfIIRgViaEy+1stFa9iWiHJT0DKJ -0fOqtwcL8OYJURyG/D29yUP5qBJcrFuIYk8uI1wtfDNMeAI4LWoWwMhBLtB6POY+ -a/qmMewFzrGGsR9R0ptwtlplhvJVeArfLYGngnbgBV4vwchjLQTR2RMouZWlwRH9 -NWX6EqsIk/zzYvu+o7sBC2839D3GCPQMmgKqSWwmlf2a76mqZk2duTO9+0v6+e+F -Qc44ndLFE+mEibXkm9PMHvPsXOUdC4KPpugC/aZbn4OCqVd3eSl7k+PZGKZua6IJ -ybhosNzQc4lg25K7iMxRXpK5WrOgEXSAA3kUquDRTWHshpz/Avwbgw== ------END RSA PRIVATE KEY----- diff --git a/notes-to-self/sleep-time.py b/notes-to-self/sleep-time.py deleted file mode 100644 index 85adf623e1..0000000000 --- a/notes-to-self/sleep-time.py +++ /dev/null @@ -1,61 +0,0 @@ -# Suppose: -# - we're blocked until a timeout occurs -# - our process gets put to sleep for a while (SIGSTOP or whatever) -# - then it gets woken up again -# what happens to our timeout? -# -# Here we do things that sleep for 6 seconds, and we put the process to sleep -# for 2 seconds in the middle of that. -# -# Results on Linux: everything takes 6 seconds, except for select.select(), -# and also time.sleep() (which on CPython uses the select() call internally) -# -# Results on macOS: everything takes 6 seconds. -# -# Why do we care: -# https://github.com/python-trio/trio/issues/591#issuecomment-498020805 - -import os -import select -import signal -import subprocess -import sys -import time - -DUR = 6 -# Can also try SIGTSTP -STOP_SIGNAL = signal.SIGSTOP - -test_progs = [ - f"import threading; ev = threading.Event(); ev.wait({DUR})", - # Python's time.sleep() calls select() internally - f"import time; time.sleep({DUR})", - # This is the real sleep() function - f"import ctypes; ctypes.CDLL(None).sleep({DUR})", - f"import select; select.select([], [], [], {DUR})", - f"import select; p = select.poll(); p.poll({DUR} * 1000)", -] -if hasattr(select, "epoll"): - test_progs += [ - f"import select; ep = select.epoll(); ep.poll({DUR})", - ] -if hasattr(select, "kqueue"): - test_progs += [ - f"import select; kq = select.kqueue(); kq.control([], 1, {DUR})", - ] - -for test_prog in test_progs: - print("----------------------------------------------------------------") - start = time.monotonic() - print(f"Running: {test_prog}") - print(f"Expected duration: {DUR} seconds") - p = subprocess.Popen([sys.executable, "-c", test_prog]) - time.sleep(DUR / 3) - print(f"Putting it to sleep for {DUR / 3} seconds") - os.kill(p.pid, STOP_SIGNAL) - time.sleep(DUR / 3) - print("Waking it up again") - os.kill(p.pid, signal.SIGCONT) - p.wait() - end = time.monotonic() - print(f"Actual duration: {end - start:.2f}") diff --git a/notes-to-self/socket-scaling.py b/notes-to-self/socket-scaling.py deleted file mode 100644 index bd7e32ef7f..0000000000 --- a/notes-to-self/socket-scaling.py +++ /dev/null @@ -1,60 +0,0 @@ -# Little script to measure how wait_readable scales with the number of -# sockets. We look at three key measurements: -# -# - cost of issuing wait_readable -# - cost of running the scheduler, while wait_readables are blocked in the -# background -# - cost of cancelling wait_readable -# -# On Linux and macOS, these all appear to be ~O(1), as we'd expect. -# -# On Windows: with the old 'select'-based loop, the cost of scheduling grew -# with the number of outstanding sockets, which was bad. -# -# To run this on Unix systems, you'll probably first have to run: -# -# ulimit -n 31000 -# -# or similar. - -import socket -import time - -import trio -import trio.testing - - -async def main(): - for total in [10, 100, 500, 1_000, 10_000, 20_000, 30_000]: - - def pt(desc, *, count=total, item="socket"): - nonlocal last_time - now = time.perf_counter() - total_ms = (now - last_time) * 1000 - per_us = total_ms * 1000 / count - print(f"{desc}: {total_ms:.2f} ms total, {per_us:.2f} ยตs/{item}") - last_time = now - - print(f"\n-- {total} sockets --") - last_time = time.perf_counter() - sockets = [] - for _ in range(total // 2): - a, b = socket.socketpair() - sockets += [a, b] - pt("socket creation") - async with trio.open_nursery() as nursery: - for s in sockets: - nursery.start_soon(trio.lowlevel.wait_readable, s) - await trio.testing.wait_all_tasks_blocked() - pt("spawning wait tasks") - for _ in range(1000): - await trio.lowlevel.cancel_shielded_checkpoint() - pt("scheduling 1000 times", count=1000, item="schedule") - nursery.cancel_scope.cancel() - pt("cancelling wait tasks") - for sock in sockets: - sock.close() - pt("closing sockets") - - -trio.run(main) diff --git a/notes-to-self/socketpair-buffering.py b/notes-to-self/socketpair-buffering.py deleted file mode 100644 index e6169c25d3..0000000000 --- a/notes-to-self/socketpair-buffering.py +++ /dev/null @@ -1,37 +0,0 @@ -import socket - -# Linux: -# low values get rounded up to ~2-4 KB, so that's predictable -# with low values, can queue up 6 one-byte sends (!) -# with default values, can queue up 278 one-byte sends -# -# Windows: -# if SNDBUF = 0 freezes, so that's useless -# by default, buffers 655121 -# with both set to 1, buffers 525347 -# except sometimes it's less intermittently (?!?) -# -# macOS: -# if bufsize = 1, can queue up 1 one-byte send -# with default bufsize, can queue up 8192 one-byte sends -# and bufsize = 0 is invalid (setsockopt errors out) - -for bufsize in [1, None, 0]: - a, b = socket.socketpair() - a.setblocking(False) - b.setblocking(False) - - a.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - if bufsize is not None: - a.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, bufsize) - b.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, bufsize) - - try: - for _count in range(10000000): - a.send(b"\x00") - except BlockingIOError: - break - - print(f"setsockopt bufsize {bufsize}: {_count}") - a.close() - b.close() diff --git a/notes-to-self/ssl-close-notify/ssl-close-notify.py b/notes-to-self/ssl-close-notify/ssl-close-notify.py deleted file mode 100644 index 7a55b8c99c..0000000000 --- a/notes-to-self/ssl-close-notify/ssl-close-notify.py +++ /dev/null @@ -1,78 +0,0 @@ -# Scenario: -# - TLS connection is set up successfully -# - client sends close_notify then closes socket -# - server receives the close_notify then attempts to send close_notify back -# -# On CPython, the last step raises BrokenPipeError. On PyPy, it raises -# SSLEOFError. -# -# SSLEOFError seems a bit perverse given that it's supposed to mean "EOF -# occurred in violation of protocol", and the client's behavior here is -# explicitly allowed by the RFCs. But maybe openssl is just perverse like -# that, and it's a coincidence that CPython and PyPy act differently here? I -# don't know if this is a bug or not. -# -# (Using: debian's CPython 3.5 or 3.6, and pypy3 5.8.0-beta) - -import socket -import ssl -import threading - -client_sock, server_sock = socket.socketpair() - -client_done = threading.Event() - - -def server_thread_fn(): - server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - server_ctx.load_cert_chain("trio-test-1.pem") - server = server_ctx.wrap_socket( - server_sock, - server_side=True, - suppress_ragged_eofs=False, - ) - while True: - data = server.recv(4096) - print("server got:", data) - if not data: - print("server waiting for client to finish everything") - client_done.wait() - print("server attempting to send back close-notify") - server.unwrap() - print("server ok") - break - server.sendall(data) - - -server_thread = threading.Thread(target=server_thread_fn) -server_thread.start() - -client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") -client = client_ctx.wrap_socket(client_sock, server_hostname="trio-test-1.example.org") - - -# Now we have two SSLSockets that have established an encrypted connection -# with each other - -assert client.getpeercert() is not None -client.sendall(b"x") -assert client.recv(10) == b"x" - -# The client sends a close-notify, and then immediately closes the connection -# (as explicitly permitted by the TLS RFCs). - -# This is a slightly odd construction, but if you trace through the ssl module -# far enough you'll see that it's equivalent to calling SSL_shutdown() once, -# which generates the close_notify, and then immediately calling it again, -# which checks for the close_notify and then immediately raises -# SSLWantReadError because of course it hasn't arrived yet: -print("client sending close_notify") -client.setblocking(False) -try: - client.unwrap() -except ssl.SSLWantReadError: - print("client got SSLWantReadError as expected") -else: - raise AssertionError() -client.close() -client_done.set() diff --git a/notes-to-self/ssl-close-notify/ssl2.py b/notes-to-self/ssl-close-notify/ssl2.py deleted file mode 100644 index 54ee1fb9b6..0000000000 --- a/notes-to-self/ssl-close-notify/ssl2.py +++ /dev/null @@ -1,63 +0,0 @@ -# This demonstrates a PyPy bug: -# https://bitbucket.org/pypy/pypy/issues/2578/ - -import socket -import ssl -import threading - -# client_sock, server_sock = socket.socketpair() -listen_sock = socket.socket() -listen_sock.bind(("127.0.0.1", 0)) -listen_sock.listen(1) -client_sock = socket.socket() -client_sock.connect(listen_sock.getsockname()) -server_sock, _ = listen_sock.accept() - -server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) -server_ctx.load_cert_chain("trio-test-1.pem") -server = server_ctx.wrap_socket( - server_sock, - server_side=True, - suppress_ragged_eofs=False, - do_handshake_on_connect=False, -) - -client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") -client = client_ctx.wrap_socket( - client_sock, - server_hostname="trio-test-1.example.org", - suppress_ragged_eofs=False, - do_handshake_on_connect=False, -) - -server_handshake_thread = threading.Thread(target=server.do_handshake) -server_handshake_thread.start() -client_handshake_thread = threading.Thread(target=client.do_handshake) -client_handshake_thread.start() - -server_handshake_thread.join() -client_handshake_thread.join() - -# Now we have two SSLSockets that have established an encrypted connection -# with each other - -assert client.getpeercert() is not None -client.sendall(b"x") -assert server.recv(10) == b"x" - -# A few different ways to make attempts to read/write the socket's fd return -# weird failures at the operating system level - -# Attempting to send on a socket after shutdown should raise EPIPE or similar -server.shutdown(socket.SHUT_WR) - -# Attempting to read/write to the fd after it's closed should raise EBADF -# os.close(server.fileno()) - -# Attempting to read/write to an fd opened with O_DIRECT raises EINVAL in most -# cases (unless you're very careful with alignment etc. which openssl isn't) -# os.dup2(os.open("/tmp/blah-example-file", os.O_RDWR | os.O_CREAT | os.O_DIRECT), server.fileno()) - -# Sending or receiving -server.sendall(b"hello") -# server.recv(10) diff --git a/notes-to-self/ssl-close-notify/trio-test-1.pem b/notes-to-self/ssl-close-notify/trio-test-1.pem deleted file mode 100644 index a0c1b773f9..0000000000 --- a/notes-to-self/ssl-close-notify/trio-test-1.pem +++ /dev/null @@ -1,64 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDZ8yz1OrHX7aHp -Erfa1ds8kmYfqYomjgy5wDsGdb8i1gF4uxhHCRDQtNZANVOVXI7R3TMchA1GMxzA -ZYDuBDuEUsqTktbTEBNb4GjOhyMu1fF4dX/tMxf7GB+flTx178eE2exTOZLmSmBa -2laoDVe3CrBAYE7nZtBF630jKKKMsUuIl0CbFRHajpoqM3e3CeCo4KcbBzgujRA3 -AsVV6y5qhMH2zqLkOYaurVUfEkdjqoHFgj1VbjWpkTbrXAxPwW6v/uZK056bHgBg -go03RyWexaPapsF2oUm2JNdSN3z7MP0umKphO2n9icyGt9Bmkm2AKs3dA45VLPXh -+NohluqJAgMBAAECggEARlfWCtAG1ko8F52S+W5MdCBMFawCiq8OLGV+p3cZWYT4 -tJ6uFz81ziaPf+m2MF7POazK8kksf5u/i9k245s6GlseRsL90uE9XknvibjUAinK -5bYGs+fptYDzs+3WtbnOC3LKc5IBd5JJxwjxLwwfY1RvzldHIChu0CJRISfcTsvR -occ8hXdeft7svNymvTuwQd05u1yjzL4RwF8Be76i17j5+jDsrAaUKdxxwGNAyOU7 -OKrUY6G851T6NUGgC19iXAJ1wN9tVGIR5QOs3J/s6dCctnX5tN8Di7prkXCKvVlm -vhpC8XWWG+c3LhS90wmEBvKS0AfUeoPDHxMOLyzKgQKBgQD07lZRO0nsc38+PVaI -NrvlP90Q8OgbwMIC52jmSZK3b5YSh3TrllsbCg6hzUk1SAJsa3qi7B1vq36Fd+rG -LGDRW9xY0cfShLhzqvZWi45zU/RYnEcWHOuXQshLikx1DWUpg2KbLSVT2/lyvzmn -QgM1Te8CSxW5vrBRVfluXoJuEwKBgQDjzLAbwk/wdjITKlQtirtsJEzWi3LGuUrg -Z2kMz+0ztUU5d1oFL9B5xh0CwK8bpK9kYnoVZSy/r5+mGHqyz1eKaDdAXIR13nC0 -g7aZbTZzbt2btvuNZc3NCzRffHF3sCqp8a+oCryHyITjZcA+WYeU8nG0TQ5O8Zgr -Skbo1JGocwKBgQC4jCx1oFqe0pd5afYdREBnB6ul7B63apHEZmBfw+fMV0OYSoAK -Uovq37UOrQMQJmXNE16gC5BSZ8E5B5XaI+3/UVvBgK8zK9VfMd3Sb+yxcPyXF4lo -W/oXSrZoVJgvShyDHv/ZNDb/7KsTjon+QHryWvpPnAuOnON1JXZ/dq6ICQKBgCZF -AukG8esR0EPL/qxP/ECksIvyjWu5QU0F0m4mmFDxiRmoZWUtrTZoBAOsXz6johuZ -N61Ue/oQBSAgSKy1jJ1h+LZFVLOAlSqeXhTUditaWryINyaADdz+nuPTwjQ7Uk+O -nNX8R8P/+eNB+tP+snphaJzDvT2h9NCA//ypiXblAoGAJoLmotPI+P3KIRVzESL0 -DAsVmeijtXE3H+R4nwqUDQbBbFKx0/u2pbON+D5C9llaGiuUp9H+awtwQRYhToeX -CNguwWrcpuhFOCeXDHDWF/0NIZYD2wBMxjF/eUarvoLaT4Gi0yyWh5ExIKOW4bFk -EojUPSJ3gomOUp5bIFcSmSU= ------END PRIVATE KEY----- ------BEGIN CERTIFICATE----- -MIICrzCCAZcCAQEwDQYJKoZIhvcNAQELBQAwFzEVMBMGA1UECgwMVHJpbyB0ZXN0 -IENBMCAXDTE3MDQwOTEwMDcyMVoYDzIyOTEwMTIyMTAwNzIxWjAiMSAwHgYDVQQD -DBd0cmlvLXRlc3QtMS5leGFtcGxlLm9yZzCCASIwDQYJKoZIhvcNAQEBBQADggEP -ADCCAQoCggEBANnzLPU6sdftoekSt9rV2zySZh+piiaODLnAOwZ1vyLWAXi7GEcJ -ENC01kA1U5VcjtHdMxyEDUYzHMBlgO4EO4RSypOS1tMQE1vgaM6HIy7V8Xh1f+0z -F/sYH5+VPHXvx4TZ7FM5kuZKYFraVqgNV7cKsEBgTudm0EXrfSMoooyxS4iXQJsV -EdqOmiozd7cJ4KjgpxsHOC6NEDcCxVXrLmqEwfbOouQ5hq6tVR8SR2OqgcWCPVVu -NamRNutcDE/Bbq/+5krTnpseAGCCjTdHJZ7Fo9qmwXahSbYk11I3fPsw/S6YqmE7 -af2JzIa30GaSbYAqzd0DjlUs9eH42iGW6okCAwEAATANBgkqhkiG9w0BAQsFAAOC -AQEAlRNA96H88lVnzlpQUYt0pwpoy7B3/CDe8Uvl41thKEfTjb+SIo95F4l+fi+l -jISWSonAYXRMNqymPMXl2ir0NigxfvvrcjggER3khASIs0l1ICwTNTv2a40NnFY6 -ZjTaBeSZ/lAi7191AkENDYvMl3aGhb6kALVIbos4/5LvJYF/UXvQfrjriLWZq/I3 -WkvduU9oSi0EA4Jt9aAhblsgDHMBL0+LU8Nl1tgzy2/NePcJWjzBRQDlF8uxCQ+2 -LesZongKQ+lebS4eYbNs0s810h8hrOEcn7VWn7FfxZRkjeaKIst2FCHmdr5JJgxj -8fw+s7l2UkrNURAJ4IRNQvPB+w== ------END CERTIFICATE----- ------BEGIN CERTIFICATE----- -MIIDBjCCAe6gAwIBAgIJAIUF+wna+nuzMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV -BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MDkxMDA3MjFaGA8yMjkxMDEyMjEwMDcy -MVowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC -AQ8AMIIBCgKCAQEAyhE82Cbq+2c2f9M+vj2f9v+z+0+bMZDUVPSXhBDiRdKubt+K -f9vY+ZH3ze1sm0iNgO6xU3OsDTlzO5z0TpsvEEbs0wgsJDUXD7Y8Fb1zH2jaVCro -Y6KcVfFZvD96zsVCnZy0vMsYJw20iIL0RNCtr17lXWVxd17OoVy91djFD9v/cixu -LRIr+N7pa8BDLUQUO/g0ui9YSC9Wgf67mr93KXKPGwjTHBGdjeZeex198j5aZjZR -lkPH/9g5d3hP7EI0EAIMDVd4dvwNJgZzv+AZINbKLAkQyE9AAm+xQ7qBSvdfAvKq -N/fwaFevmyrxUBcfoQxSpds8njWDb3dQzCn7ywIDAQABo1MwUTAdBgNVHQ4EFgQU -JiilveozF8Qpyy2fS3wV4foVRCswHwYDVR0jBBgwFoAUJiilveozF8Qpyy2fS3wV -4foVRCswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAkcthPjkJ -+npvMKyQtUx7CSVcT4Ar0jHrvfPSg17ipyLv+MhSTIbS2VwhSYXxNmu+oBWKuYUs -BnNxly3+SOcs+dTP3GBMng91SBsz5hhbP4ws8uUtCvYJauzeHbeY67R14RT8Ws/b -mP6HDiybN7zy6LOKGCiz+sCoJqVZG/yBYO87iQsTTyNttgoG27yUSvzP07EQwUa5 -F9dI9Wn+4b5wP2ofMCu3asTbKXjfFbz3w5OkRgpGYhC4jhDdOw/819+01R9//GrM -54Gme03yDAAM7nGihr1Xtld3dp2gLuqv0WgxKBqvG5X+nCbr2WamscAP5qz149vo -y6Hq6P4mm2GmZw== ------END CERTIFICATE----- diff --git a/notes-to-self/ssl-close-notify/trio-test-CA.pem b/notes-to-self/ssl-close-notify/trio-test-CA.pem deleted file mode 100644 index 9bf34001b2..0000000000 --- a/notes-to-self/ssl-close-notify/trio-test-CA.pem +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDBjCCAe6gAwIBAgIJAIUF+wna+nuzMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV -BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MDkxMDA3MjFaGA8yMjkxMDEyMjEwMDcy -MVowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC -AQ8AMIIBCgKCAQEAyhE82Cbq+2c2f9M+vj2f9v+z+0+bMZDUVPSXhBDiRdKubt+K -f9vY+ZH3ze1sm0iNgO6xU3OsDTlzO5z0TpsvEEbs0wgsJDUXD7Y8Fb1zH2jaVCro -Y6KcVfFZvD96zsVCnZy0vMsYJw20iIL0RNCtr17lXWVxd17OoVy91djFD9v/cixu -LRIr+N7pa8BDLUQUO/g0ui9YSC9Wgf67mr93KXKPGwjTHBGdjeZeex198j5aZjZR -lkPH/9g5d3hP7EI0EAIMDVd4dvwNJgZzv+AZINbKLAkQyE9AAm+xQ7qBSvdfAvKq -N/fwaFevmyrxUBcfoQxSpds8njWDb3dQzCn7ywIDAQABo1MwUTAdBgNVHQ4EFgQU -JiilveozF8Qpyy2fS3wV4foVRCswHwYDVR0jBBgwFoAUJiilveozF8Qpyy2fS3wV -4foVRCswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAkcthPjkJ -+npvMKyQtUx7CSVcT4Ar0jHrvfPSg17ipyLv+MhSTIbS2VwhSYXxNmu+oBWKuYUs -BnNxly3+SOcs+dTP3GBMng91SBsz5hhbP4ws8uUtCvYJauzeHbeY67R14RT8Ws/b -mP6HDiybN7zy6LOKGCiz+sCoJqVZG/yBYO87iQsTTyNttgoG27yUSvzP07EQwUa5 -F9dI9Wn+4b5wP2ofMCu3asTbKXjfFbz3w5OkRgpGYhC4jhDdOw/819+01R9//GrM -54Gme03yDAAM7nGihr1Xtld3dp2gLuqv0WgxKBqvG5X+nCbr2WamscAP5qz149vo -y6Hq6P4mm2GmZw== ------END CERTIFICATE----- diff --git a/notes-to-self/ssl-handshake/ssl-handshake.py b/notes-to-self/ssl-handshake/ssl-handshake.py deleted file mode 100644 index e906bc2a87..0000000000 --- a/notes-to-self/ssl-handshake/ssl-handshake.py +++ /dev/null @@ -1,139 +0,0 @@ -import socket -import ssl -import threading -from contextlib import contextmanager - -BUFSIZE = 4096 - -server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) -server_ctx.load_cert_chain("trio-test-1.pem") - - -def _ssl_echo_serve_sync(sock): - try: - wrapped = server_ctx.wrap_socket(sock, server_side=True) - while True: - data = wrapped.recv(BUFSIZE) - if not data: - wrapped.unwrap() - return - wrapped.sendall(data) - except BrokenPipeError: - pass - - -@contextmanager -def echo_server_connection(): - client_sock, server_sock = socket.socketpair() - with client_sock, server_sock: - t = threading.Thread( - target=_ssl_echo_serve_sync, args=(server_sock,), daemon=True - ) - t.start() - - yield client_sock - - -class ManuallyWrappedSocket: - def __init__(self, ctx, sock, **kwargs): - self.incoming = ssl.MemoryBIO() - self.outgoing = ssl.MemoryBIO() - self.obj = ctx.wrap_bio(self.incoming, self.outgoing, **kwargs) - self.sock = sock - - def _retry(self, fn, *args): - finished = False - while not finished: - want_read = False - try: - ret = fn(*args) - except ssl.SSLWantReadError: - want_read = True - except ssl.SSLWantWriteError: - # can't happen, but if it did this would be the right way to - # handle it anyway - pass - else: - finished = True - # do any sending - data = self.outgoing.read() - if data: - self.sock.sendall(data) - # do any receiving - if want_read: - data = self.sock.recv(BUFSIZE) - if not data: - self.incoming.write_eof() - else: - self.incoming.write(data) - # then retry if necessary - return ret - - def do_handshake(self): - self._retry(self.obj.do_handshake) - - def recv(self, bufsize): - return self._retry(self.obj.read, bufsize) - - def sendall(self, data): - self._retry(self.obj.write, data) - - def unwrap(self): - self._retry(self.obj.unwrap) - return self.sock - - -def wrap_socket_via_wrap_socket(ctx, sock, **kwargs): - return ctx.wrap_socket(sock, do_handshake_on_connect=False, **kwargs) - - -def wrap_socket_via_wrap_bio(ctx, sock, **kwargs): - return ManuallyWrappedSocket(ctx, sock, **kwargs) - - -for wrap_socket in [ - wrap_socket_via_wrap_socket, - wrap_socket_via_wrap_bio, -]: - print(f"\n--- checking {wrap_socket.__name__} ---\n") - - print("checking with do_handshake + correct hostname...") - with echo_server_connection() as client_sock: - client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") - wrapped = wrap_socket( - client_ctx, client_sock, server_hostname="trio-test-1.example.org" - ) - wrapped.do_handshake() - wrapped.sendall(b"x") - assert wrapped.recv(1) == b"x" - wrapped.unwrap() - print("...success") - - print("checking with do_handshake + wrong hostname...") - with echo_server_connection() as client_sock: - client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") - wrapped = wrap_socket( - client_ctx, client_sock, server_hostname="trio-test-2.example.org" - ) - try: - wrapped.do_handshake() - except Exception: - print("...got error as expected") - else: - print("??? no error ???") - - print("checking withOUT do_handshake + wrong hostname...") - with echo_server_connection() as client_sock: - client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") - wrapped = wrap_socket( - client_ctx, client_sock, server_hostname="trio-test-2.example.org" - ) - # We forgot to call do_handshake - # But the hostname is wrong so something had better error out... - sent = b"x" - print("sending", sent) - wrapped.sendall(sent) - got = wrapped.recv(1) - print("got:", got) - assert got == sent - print("!!!! successful chat with invalid host! we have been haxored!") diff --git a/notes-to-self/ssl-handshake/trio-test-1.pem b/notes-to-self/ssl-handshake/trio-test-1.pem deleted file mode 100644 index a0c1b773f9..0000000000 --- a/notes-to-self/ssl-handshake/trio-test-1.pem +++ /dev/null @@ -1,64 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDZ8yz1OrHX7aHp -Erfa1ds8kmYfqYomjgy5wDsGdb8i1gF4uxhHCRDQtNZANVOVXI7R3TMchA1GMxzA -ZYDuBDuEUsqTktbTEBNb4GjOhyMu1fF4dX/tMxf7GB+flTx178eE2exTOZLmSmBa -2laoDVe3CrBAYE7nZtBF630jKKKMsUuIl0CbFRHajpoqM3e3CeCo4KcbBzgujRA3 -AsVV6y5qhMH2zqLkOYaurVUfEkdjqoHFgj1VbjWpkTbrXAxPwW6v/uZK056bHgBg -go03RyWexaPapsF2oUm2JNdSN3z7MP0umKphO2n9icyGt9Bmkm2AKs3dA45VLPXh -+NohluqJAgMBAAECggEARlfWCtAG1ko8F52S+W5MdCBMFawCiq8OLGV+p3cZWYT4 -tJ6uFz81ziaPf+m2MF7POazK8kksf5u/i9k245s6GlseRsL90uE9XknvibjUAinK -5bYGs+fptYDzs+3WtbnOC3LKc5IBd5JJxwjxLwwfY1RvzldHIChu0CJRISfcTsvR -occ8hXdeft7svNymvTuwQd05u1yjzL4RwF8Be76i17j5+jDsrAaUKdxxwGNAyOU7 -OKrUY6G851T6NUGgC19iXAJ1wN9tVGIR5QOs3J/s6dCctnX5tN8Di7prkXCKvVlm -vhpC8XWWG+c3LhS90wmEBvKS0AfUeoPDHxMOLyzKgQKBgQD07lZRO0nsc38+PVaI -NrvlP90Q8OgbwMIC52jmSZK3b5YSh3TrllsbCg6hzUk1SAJsa3qi7B1vq36Fd+rG -LGDRW9xY0cfShLhzqvZWi45zU/RYnEcWHOuXQshLikx1DWUpg2KbLSVT2/lyvzmn -QgM1Te8CSxW5vrBRVfluXoJuEwKBgQDjzLAbwk/wdjITKlQtirtsJEzWi3LGuUrg -Z2kMz+0ztUU5d1oFL9B5xh0CwK8bpK9kYnoVZSy/r5+mGHqyz1eKaDdAXIR13nC0 -g7aZbTZzbt2btvuNZc3NCzRffHF3sCqp8a+oCryHyITjZcA+WYeU8nG0TQ5O8Zgr -Skbo1JGocwKBgQC4jCx1oFqe0pd5afYdREBnB6ul7B63apHEZmBfw+fMV0OYSoAK -Uovq37UOrQMQJmXNE16gC5BSZ8E5B5XaI+3/UVvBgK8zK9VfMd3Sb+yxcPyXF4lo -W/oXSrZoVJgvShyDHv/ZNDb/7KsTjon+QHryWvpPnAuOnON1JXZ/dq6ICQKBgCZF -AukG8esR0EPL/qxP/ECksIvyjWu5QU0F0m4mmFDxiRmoZWUtrTZoBAOsXz6johuZ -N61Ue/oQBSAgSKy1jJ1h+LZFVLOAlSqeXhTUditaWryINyaADdz+nuPTwjQ7Uk+O -nNX8R8P/+eNB+tP+snphaJzDvT2h9NCA//ypiXblAoGAJoLmotPI+P3KIRVzESL0 -DAsVmeijtXE3H+R4nwqUDQbBbFKx0/u2pbON+D5C9llaGiuUp9H+awtwQRYhToeX -CNguwWrcpuhFOCeXDHDWF/0NIZYD2wBMxjF/eUarvoLaT4Gi0yyWh5ExIKOW4bFk -EojUPSJ3gomOUp5bIFcSmSU= ------END PRIVATE KEY----- ------BEGIN CERTIFICATE----- -MIICrzCCAZcCAQEwDQYJKoZIhvcNAQELBQAwFzEVMBMGA1UECgwMVHJpbyB0ZXN0 -IENBMCAXDTE3MDQwOTEwMDcyMVoYDzIyOTEwMTIyMTAwNzIxWjAiMSAwHgYDVQQD -DBd0cmlvLXRlc3QtMS5leGFtcGxlLm9yZzCCASIwDQYJKoZIhvcNAQEBBQADggEP -ADCCAQoCggEBANnzLPU6sdftoekSt9rV2zySZh+piiaODLnAOwZ1vyLWAXi7GEcJ -ENC01kA1U5VcjtHdMxyEDUYzHMBlgO4EO4RSypOS1tMQE1vgaM6HIy7V8Xh1f+0z -F/sYH5+VPHXvx4TZ7FM5kuZKYFraVqgNV7cKsEBgTudm0EXrfSMoooyxS4iXQJsV -EdqOmiozd7cJ4KjgpxsHOC6NEDcCxVXrLmqEwfbOouQ5hq6tVR8SR2OqgcWCPVVu -NamRNutcDE/Bbq/+5krTnpseAGCCjTdHJZ7Fo9qmwXahSbYk11I3fPsw/S6YqmE7 -af2JzIa30GaSbYAqzd0DjlUs9eH42iGW6okCAwEAATANBgkqhkiG9w0BAQsFAAOC -AQEAlRNA96H88lVnzlpQUYt0pwpoy7B3/CDe8Uvl41thKEfTjb+SIo95F4l+fi+l -jISWSonAYXRMNqymPMXl2ir0NigxfvvrcjggER3khASIs0l1ICwTNTv2a40NnFY6 -ZjTaBeSZ/lAi7191AkENDYvMl3aGhb6kALVIbos4/5LvJYF/UXvQfrjriLWZq/I3 -WkvduU9oSi0EA4Jt9aAhblsgDHMBL0+LU8Nl1tgzy2/NePcJWjzBRQDlF8uxCQ+2 -LesZongKQ+lebS4eYbNs0s810h8hrOEcn7VWn7FfxZRkjeaKIst2FCHmdr5JJgxj -8fw+s7l2UkrNURAJ4IRNQvPB+w== ------END CERTIFICATE----- ------BEGIN CERTIFICATE----- -MIIDBjCCAe6gAwIBAgIJAIUF+wna+nuzMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV -BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MDkxMDA3MjFaGA8yMjkxMDEyMjEwMDcy -MVowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC -AQ8AMIIBCgKCAQEAyhE82Cbq+2c2f9M+vj2f9v+z+0+bMZDUVPSXhBDiRdKubt+K -f9vY+ZH3ze1sm0iNgO6xU3OsDTlzO5z0TpsvEEbs0wgsJDUXD7Y8Fb1zH2jaVCro -Y6KcVfFZvD96zsVCnZy0vMsYJw20iIL0RNCtr17lXWVxd17OoVy91djFD9v/cixu -LRIr+N7pa8BDLUQUO/g0ui9YSC9Wgf67mr93KXKPGwjTHBGdjeZeex198j5aZjZR -lkPH/9g5d3hP7EI0EAIMDVd4dvwNJgZzv+AZINbKLAkQyE9AAm+xQ7qBSvdfAvKq -N/fwaFevmyrxUBcfoQxSpds8njWDb3dQzCn7ywIDAQABo1MwUTAdBgNVHQ4EFgQU -JiilveozF8Qpyy2fS3wV4foVRCswHwYDVR0jBBgwFoAUJiilveozF8Qpyy2fS3wV -4foVRCswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAkcthPjkJ -+npvMKyQtUx7CSVcT4Ar0jHrvfPSg17ipyLv+MhSTIbS2VwhSYXxNmu+oBWKuYUs -BnNxly3+SOcs+dTP3GBMng91SBsz5hhbP4ws8uUtCvYJauzeHbeY67R14RT8Ws/b -mP6HDiybN7zy6LOKGCiz+sCoJqVZG/yBYO87iQsTTyNttgoG27yUSvzP07EQwUa5 -F9dI9Wn+4b5wP2ofMCu3asTbKXjfFbz3w5OkRgpGYhC4jhDdOw/819+01R9//GrM -54Gme03yDAAM7nGihr1Xtld3dp2gLuqv0WgxKBqvG5X+nCbr2WamscAP5qz149vo -y6Hq6P4mm2GmZw== ------END CERTIFICATE----- diff --git a/notes-to-self/ssl-handshake/trio-test-CA.pem b/notes-to-self/ssl-handshake/trio-test-CA.pem deleted file mode 100644 index 9bf34001b2..0000000000 --- a/notes-to-self/ssl-handshake/trio-test-CA.pem +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDBjCCAe6gAwIBAgIJAIUF+wna+nuzMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV -BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MDkxMDA3MjFaGA8yMjkxMDEyMjEwMDcy -MVowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC -AQ8AMIIBCgKCAQEAyhE82Cbq+2c2f9M+vj2f9v+z+0+bMZDUVPSXhBDiRdKubt+K -f9vY+ZH3ze1sm0iNgO6xU3OsDTlzO5z0TpsvEEbs0wgsJDUXD7Y8Fb1zH2jaVCro -Y6KcVfFZvD96zsVCnZy0vMsYJw20iIL0RNCtr17lXWVxd17OoVy91djFD9v/cixu -LRIr+N7pa8BDLUQUO/g0ui9YSC9Wgf67mr93KXKPGwjTHBGdjeZeex198j5aZjZR -lkPH/9g5d3hP7EI0EAIMDVd4dvwNJgZzv+AZINbKLAkQyE9AAm+xQ7qBSvdfAvKq -N/fwaFevmyrxUBcfoQxSpds8njWDb3dQzCn7ywIDAQABo1MwUTAdBgNVHQ4EFgQU -JiilveozF8Qpyy2fS3wV4foVRCswHwYDVR0jBBgwFoAUJiilveozF8Qpyy2fS3wV -4foVRCswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAkcthPjkJ -+npvMKyQtUx7CSVcT4Ar0jHrvfPSg17ipyLv+MhSTIbS2VwhSYXxNmu+oBWKuYUs -BnNxly3+SOcs+dTP3GBMng91SBsz5hhbP4ws8uUtCvYJauzeHbeY67R14RT8Ws/b -mP6HDiybN7zy6LOKGCiz+sCoJqVZG/yBYO87iQsTTyNttgoG27yUSvzP07EQwUa5 -F9dI9Wn+4b5wP2ofMCu3asTbKXjfFbz3w5OkRgpGYhC4jhDdOw/819+01R9//GrM -54Gme03yDAAM7nGihr1Xtld3dp2gLuqv0WgxKBqvG5X+nCbr2WamscAP5qz149vo -y6Hq6P4mm2GmZw== ------END CERTIFICATE----- diff --git a/notes-to-self/sslobject.py b/notes-to-self/sslobject.py deleted file mode 100644 index a6e7b07a08..0000000000 --- a/notes-to-self/sslobject.py +++ /dev/null @@ -1,78 +0,0 @@ -import ssl -from contextlib import contextmanager - -client_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) -client_ctx.check_hostname = False -client_ctx.verify_mode = ssl.CERT_NONE - -cinb = ssl.MemoryBIO() -coutb = ssl.MemoryBIO() -cso = client_ctx.wrap_bio(cinb, coutb) - -server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) -server_ctx.load_cert_chain("server.crt", "server.key", "xxxx") -sinb = ssl.MemoryBIO() -soutb = ssl.MemoryBIO() -sso = server_ctx.wrap_bio(sinb, soutb, server_side=True) - - -@contextmanager -def expect(etype): - try: - yield - except etype: - pass - else: - raise AssertionError(f"expected {etype}") - - -with expect(ssl.SSLWantReadError): - cso.do_handshake() -assert not cinb.pending -assert coutb.pending - -with expect(ssl.SSLWantReadError): - sso.do_handshake() -assert not sinb.pending -assert not soutb.pending - -# A trickle is not enough -# sinb.write(coutb.read(1)) -# with expect(ssl.SSLWantReadError): -# cso.do_handshake() -# with expect(ssl.SSLWantReadError): -# sso.do_handshake() - -sinb.write(coutb.read()) -# Now it should be able to respond -with expect(ssl.SSLWantReadError): - sso.do_handshake() -assert soutb.pending - -cinb.write(soutb.read()) -with expect(ssl.SSLWantReadError): - cso.do_handshake() - -sinb.write(coutb.read()) -# server done! -sso.do_handshake() -assert soutb.pending - -# client done! -cinb.write(soutb.read()) -cso.do_handshake() - -cso.write(b"hello") -sinb.write(coutb.read()) -assert sso.read(10) == b"hello" -with expect(ssl.SSLWantReadError): - sso.read(10) - -# cso.write(b"x" * 2 ** 30) -# print(coutb.pending) - -assert not coutb.pending -assert not cinb.pending -sso.do_handshake() -assert not coutb.pending -assert not cinb.pending diff --git a/notes-to-self/subprocess-notes.txt b/notes-to-self/subprocess-notes.txt deleted file mode 100644 index d3a5c1096c..0000000000 --- a/notes-to-self/subprocess-notes.txt +++ /dev/null @@ -1,73 +0,0 @@ -# subprocesses are a huge hassle -# on Linux there is simply no way to async wait for a child to exit except by -# messing with SIGCHLD and that is ... *such* a mess. Not really -# tenable. We're better off trying os.waitpid(..., os.WNOHANG), and if that -# says the process is still going then spawn a thread to sit in waitpid. -# ......though that waitpid is non-cancellable so ugh. this is a problem, -# because it's also mutating -- you only get to waitpid() once, and you have -# to do it, because zombies. I guess we could make sure the waitpid thread is -# daemonic and either it gets back to us eventually (even if our first call to -# 'await wait()' is cancelled, maybe another one won't be), or else we go away -# and don't care anymore. -# I guess simplest is just to spawn a thread at the same time as we spawn the -# process, with more reasonable notification semantics. -# or we can poll every 100 ms or something, sigh. - -# on Mac/*BSD then kqueue works, go them. (maybe have WNOHANG after turning it -# on to avoid a race condition I guess) - -# on Windows, you can either do the thread thing, or something involving -# WaitForMultipleObjects, or the Job Object API: -# https://stackoverflow.com/questions/17724859/detecting-exit-failure-of-child-processes-using-iocp-c-windows -# (see also the comments here about using the Job Object API: -# https://stackoverflow.com/questions/23434842/python-how-to-kill-child-processes-when-parent-dies/23587108#23587108) -# however the docs say: -# "Note that, with the exception of limits set with the -# JobObjectNotificationLimitInformation information class, delivery of -# messages to the completion port is not guaranteed; failure of a message to -# arrive does not necessarily mean that the event did not occur" -# -# oh windows wtf - -# We'll probably want to mess with the job API anyway for worker processes -# (b/c that's the reliable way to make sure we never leave residual worker -# processes around after exiting, see that stackoverflow question again), so -# maybe this isn't too big a hassle? waitpid is probably easiest for the -# first-pass implementation though. - -# the handle version has the same issues as waitpid on Linux, except I guess -# that on windows the waitpid equivalent doesn't consume the handle. -# -- wait no, the windows equivalent takes a timeout! and we know our -# cancellation deadline going in, so that's actually okay. (Still need to use -# a thread but whatever.) - -# asyncio does RegisterWaitForSingleObject with a callback that does -# PostQueuedCompletionStatus. -# this is just a thread pool in disguise (and in principle could have weird -# problems if you have enough children and run out of threads) -# it's possible we could do something with a thread that just sits in -# an alertable state and handle callbacks...? though hmm, maybe the set of -# events that can notify via callbacks is equivalent to the set that can -# notify via IOCP. -# there's WaitForMultipleObjects to let multiple waits share a thread I -# guess. -# you can wake up a WaitForMultipleObjectsEx on-demand by using QueueUserAPC -# to send a no-op APC to its thread. -# this is also a way to cancel a WaitForSingleObjectEx, actually. So it -# actually is possible to cancel the equivalent of a waitpid on Windows. - -# Potentially useful observation: you *can* use a socket as the -# stdin/stdout/stderr for a child, iff you create that socket *without* -# WSA_FLAG_OVERLAPPED: -# http://stackoverflow.com/a/5725609 -# Here's ncm's Windows implementation of socketpair, which has a flag to -# control whether one of the sockets has WSA_FLAG_OVERLAPPED set: -# https://github.com/ncm/selectable-socketpair/blob/master/socketpair.c -# (it also uses listen(1) so it's robust against someone intercepting things, -# unlike the version in socket.py... not sure anyone really cares, but -# hey. OTOH it only supports AF_INET, while socket.py supports AF_INET6, -# fancy.) -# (or it would be trivial to (re)implement in python, using either -# socket.socketpair or ncm's version as a model, given a cffi function to -# create the non-overlapped socket in the first place then just pass it into -# the socket.socket constructor (avoiding the dup() that fromfd does).) diff --git a/notes-to-self/thread-closure-bug-demo.py b/notes-to-self/thread-closure-bug-demo.py deleted file mode 100644 index b5da68c334..0000000000 --- a/notes-to-self/thread-closure-bug-demo.py +++ /dev/null @@ -1,60 +0,0 @@ -# This is a reproducer for: -# https://bugs.python.org/issue30744 -# https://bitbucket.org/pypy/pypy/issues/2591/ - -import sys -import threading -import time - -COUNT = 100 - - -def slow_tracefunc(frame, event, arg): - # A no-op trace function that sleeps briefly to make us more likely to hit - # the race condition. - time.sleep(0.01) - return slow_tracefunc - - -def run_with_slow_tracefunc(fn): - # settrace() only takes effect when you enter a new frame, so we need this - # little dance: - sys.settrace(slow_tracefunc) - return fn() - - -def outer(): - x = 0 - # We hide the done variable inside a list, because we want to use it to - # communicate between the main thread and the looper thread, and the bug - # here is that variable assignments made in the main thread disappear - # before the child thread can see them... - done = [False] - - def traced_looper(): - # Force w_locals to be instantiated (only matters on PyPy; on CPython - # you can comment this line out and everything stays the same) - print(locals()) - nonlocal x # Force x to be closed over - # Random nonsense whose only purpose is to trigger lots of calls to - # the trace func - count = 0 - while not done[0]: - count += 1 - return count - - t = threading.Thread(target=run_with_slow_tracefunc, args=(traced_looper,)) - t.start() - - for i in range(COUNT): - print(f"after {i} increments, x is {x}") - x += 1 - time.sleep(0.01) - - done[0] = True - t.join() - - print(f"Final discrepancy: {COUNT - x} (should be 0)") - - -outer() diff --git a/notes-to-self/thread-dispatch-bench.py b/notes-to-self/thread-dispatch-bench.py deleted file mode 100644 index 70547a6000..0000000000 --- a/notes-to-self/thread-dispatch-bench.py +++ /dev/null @@ -1,36 +0,0 @@ -# Estimate the cost of simply passing some data into a thread and back, in as -# minimal a fashion as possible. -# -# This is useful to get a sense of the *lower-bound* cost of -# trio.to_thread.run_sync - -import threading -import time -from queue import Queue - -COUNT = 10000 - - -def worker(in_q, out_q): - while True: - job = in_q.get() - out_q.put(job()) - - -def main(): - in_q = Queue() - out_q = Queue() - - t = threading.Thread(target=worker, args=(in_q, out_q)) - t.start() - - while True: - start = time.monotonic() - for _ in range(COUNT): - in_q.put(lambda: None) - out_q.get() - end = time.monotonic() - print(f"{(end - start) / COUNT * 1e6:.2f} ยตs/job") - - -main() diff --git a/notes-to-self/time-wait-windows-exclusiveaddruse.py b/notes-to-self/time-wait-windows-exclusiveaddruse.py deleted file mode 100644 index dcb4a27dd0..0000000000 --- a/notes-to-self/time-wait-windows-exclusiveaddruse.py +++ /dev/null @@ -1,69 +0,0 @@ -# On windows, what does SO_EXCLUSIVEADDRUSE actually do? Apparently not what -# the documentation says! -# See: https://stackoverflow.com/questions/45624916/ -# -# Specifically, this script seems to demonstrate that it only creates -# conflicts between listening sockets, *not* lingering connected sockets. - -import socket -from contextlib import contextmanager - - -@contextmanager -def report_outcome(tagline): - try: - yield - except OSError as exc: - print(f"{tagline}: failed") - print(f" details: {exc!r}") - else: - print(f"{tagline}: succeeded") - - -# Set up initial listening socket -lsock = socket.socket() -lsock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) -lsock.bind(("127.0.0.1", 0)) -sockaddr = lsock.getsockname() -lsock.listen(10) - -# Make connected client and server sockets -csock = socket.socket() -csock.connect(sockaddr) -ssock, _ = lsock.accept() - -print("lsock", lsock.getsockname()) -print("ssock", ssock.getsockname()) - -# Can't make a second listener while the first exists -probe = socket.socket() -probe.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) -with report_outcome("rebind with existing listening socket"): - probe.bind(sockaddr) - -# Now we close the first listen socket, while leaving the connected sockets -# open: -lsock.close() -# This time binding succeeds (contra MSDN!) -probe = socket.socket() -probe.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) -with report_outcome("rebind with live connected sockets"): - probe.bind(sockaddr) - probe.listen(10) - print("probe", probe.getsockname()) - print("ssock", ssock.getsockname()) -probe.close() - -# Server-initiated close to trigger TIME_WAIT status -ssock.send(b"x") -assert csock.recv(1) == b"x" -ssock.close() -assert csock.recv(1) == b"" - -# And does the TIME_WAIT sock prevent binding? -probe = socket.socket() -probe.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) -with report_outcome("rebind with TIME_WAIT socket"): - probe.bind(sockaddr) - probe.listen(10) -probe.close() diff --git a/notes-to-self/time-wait.py b/notes-to-self/time-wait.py deleted file mode 100644 index edc1b39172..0000000000 --- a/notes-to-self/time-wait.py +++ /dev/null @@ -1,113 +0,0 @@ -# what does SO_REUSEADDR do, exactly? - -# Theory: -# -# - listen1 is bound to port P -# - listen1.accept() returns a connected socket server1, which is also bound -# to port P -# - listen1 is closed -# - we attempt to bind listen2 to port P -# - this fails because server1 is still open, or still in TIME_WAIT, and you -# can't use bind() to bind to a port that still has sockets on it, unless -# both those sockets and the socket being bound have SO_REUSEADDR -# -# The standard way to avoid this is to set SO_REUSEADDR on all listening -# sockets before binding them. And this works, but for somewhat more -# complicated reasons than are often appreciated. -# -# In our scenario above it doesn't really matter for listen1 (assuming the -# port is initially unused). -# -# What is important is that it's set on *server1*. Setting it on listen1 -# before calling bind() automatically accomplishes this, because SO_REUSEADDR -# is inherited by accept()ed sockets. But it also works to set it on listen1 -# any time before calling accept(), or to set it on server1 directly. -# -# Also, it must be set on listen2 before calling bind(), or it will conflict -# with the lingering server1 socket. - -import errno -import socket - -import attrs - - -@attrs.define(repr=False, slots=False) -class Options: - listen1_early = None - listen1_middle = None - listen1_late = None - server = None - listen2 = None - - def set(self, which, sock): - value = getattr(self, which) - if value is not None: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, value) - - def describe(self): - info = [] - for f in attrs.fields(self.__class__): - value = getattr(self, f.name) - if value is not None: - info.append(f"{f.name}={value}") - return "Set/unset: {}".format(", ".join(info)) - - -def time_wait(options): - print(options.describe()) - - # Find a pristine port (one we can definitely bind to without - # SO_REUSEADDR) - listen0 = socket.socket() - listen0.bind(("127.0.0.1", 0)) - sockaddr = listen0.getsockname() - # print(" ", sockaddr) - listen0.close() - - listen1 = socket.socket() - options.set("listen1_early", listen1) - listen1.bind(sockaddr) - listen1.listen(1) - - options.set("listen1_middle", listen1) - - client = socket.socket() - client.connect(sockaddr) - - options.set("listen1_late", listen1) - - server, _ = listen1.accept() - - options.set("server", server) - - # Server initiated close to trigger TIME_WAIT status - server.close() - assert client.recv(10) == b"" - client.close() - - listen1.close() - - listen2 = socket.socket() - options.set("listen2", listen2) - try: - listen2.bind(sockaddr) - except OSError as exc: - if exc.errno == errno.EADDRINUSE: - print(" -> EADDRINUSE") - else: - raise - else: - print(" -> ok") - - -time_wait(Options()) -time_wait(Options(listen1_early=True, server=True, listen2=True)) -time_wait(Options(listen1_early=True)) -time_wait(Options(server=True)) -time_wait(Options(listen2=True)) -time_wait(Options(listen1_early=True, listen2=True)) -time_wait(Options(server=True, listen2=True)) -time_wait(Options(listen1_middle=True, listen2=True)) -time_wait(Options(listen1_late=True, listen2=True)) -time_wait(Options(listen1_middle=True, server=False, listen2=True)) diff --git a/notes-to-self/trace.py b/notes-to-self/trace.py deleted file mode 100644 index 046412d3ae..0000000000 --- a/notes-to-self/trace.py +++ /dev/null @@ -1,154 +0,0 @@ -import json -import os -from itertools import count - -import trio - -# Experiment with generating Chrome Event Trace format, which can be browsed -# through chrome://tracing or other mechanisms. -# -# Screenshot: https://files.gitter.im/python-trio/general/fp6w/image.png -# -# Trace format docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview# -# -# Things learned so far: -# - I don't understand how the ph="s"/ph="f" flow events work โ€“ I think -# they're supposed to show up as arrows, and I'm emitting them between tasks -# that wake each other up, but they're not showing up. -# - I think writing out json synchronously from each event is creating gaps in -# the trace; maybe better to batch them up to write up all at once at the -# end -# - including tracebacks would be cool -# - there doesn't seem to be any good way to group together tasks based on -# nurseries. this really limits the value of this particular trace -# format+viewer for us. (also maybe we should have an instrumentation event -# when a nursery is opened/closed?) -# - task._counter should maybe be public -# - I don't know how to best show task lifetime, scheduling times, and what -# the task is actually doing on the same plot. if we want to show particular -# events like "called stream.send_all", then the chrome trace format won't -# let us also show "task is running", because neither kind of event is -# strictly nested inside the other - - -class Trace(trio.abc.Instrument): - def __init__(self, out): - self.out = out - self.out.write("[\n") - self.ids = count() - self._task_metadata(-1, "I/O manager") - - def _write(self, **ev): - ev.setdefault("pid", os.getpid()) - if ev["ph"] != "M": - ev.setdefault("ts", trio.current_time() * 1e6) - self.out.write(json.dumps(ev)) - self.out.write(",\n") - - def _task_metadata(self, tid, name): - self._write( - name="thread_name", - ph="M", - tid=tid, - args={"name": name}, - ) - self._write( - name="thread_sort_index", - ph="M", - tid=tid, - args={"sort_index": tid}, - ) - - def task_spawned(self, task): - self._task_metadata(task._counter, task.name) - self._write( - name="task lifetime", - ph="B", - tid=task._counter, - ) - - def task_exited(self, task): - self._write( - name="task lifetime", - ph="E", - tid=task._counter, - ) - - def before_task_step(self, task): - self._write( - name="running", - ph="B", - tid=task._counter, - ) - - def after_task_step(self, task): - self._write( - name="running", - ph="E", - tid=task._counter, - ) - - def task_scheduled(self, task): - try: - waker = trio.lowlevel.current_task() - except RuntimeError: - pass - else: - id_ = next(self.ids) - self._write( - ph="s", - cat="wakeup", - id=id_, - tid=waker._counter, - ) - self._write( - cat="wakeup", - ph="f", - id=id_, - tid=task._counter, - ) - - def before_io_wait(self, timeout): - self._write( - name="I/O wait", - ph="B", - tid=-1, - ) - - def after_io_wait(self, timeout): - self._write( - name="I/O wait", - ph="E", - tid=-1, - ) - - -async def child1(): - print(" child1: started! sleeping now...") - await trio.sleep(1) - print(" child1: exiting!") - - -async def child2(): - print(" child2: started! sleeping now...") - await trio.sleep(1) - print(" child2: exiting!") - - -async def parent(): - print("parent: started!") - async with trio.open_nursery() as nursery: - print("parent: spawning child1...") - nursery.start_soon(child1) - - print("parent: spawning child2...") - nursery.start_soon(child2) - - print("parent: waiting for children to finish...") - # -- we exit the nursery block here -- - print("parent: all done!") - - -with open("/tmp/t.json", "w") as t_json: - t = Trace(t_json) - trio.run(parent, instruments=[t]) diff --git a/notes-to-self/trivial-err.py b/notes-to-self/trivial-err.py deleted file mode 100644 index 6c32617c74..0000000000 --- a/notes-to-self/trivial-err.py +++ /dev/null @@ -1,33 +0,0 @@ -import sys - -import trio - -sys.stderr = sys.stdout - - -async def child1(): - raise ValueError - - -async def child2(): - async with trio.open_nursery() as nursery: - nursery.start_soon(grandchild1) - nursery.start_soon(grandchild2) - - -async def grandchild1(): - raise KeyError - - -async def grandchild2(): - raise NameError("Bob") - - -async def main(): - async with trio.open_nursery() as nursery: - nursery.start_soon(child1) - nursery.start_soon(child2) - # nursery.start_soon(grandchild1) - - -trio.run(main) diff --git a/notes-to-self/trivial.py b/notes-to-self/trivial.py deleted file mode 100644 index 405d92daf5..0000000000 --- a/notes-to-self/trivial.py +++ /dev/null @@ -1,10 +0,0 @@ -import trio - - -async def foo(): - print("in foo!") - return 3 - - -print("running!") -print(trio.run(foo)) diff --git a/notes-to-self/wakeup-fd-racer.py b/notes-to-self/wakeup-fd-racer.py deleted file mode 100644 index b56cbdc91c..0000000000 --- a/notes-to-self/wakeup-fd-racer.py +++ /dev/null @@ -1,105 +0,0 @@ -import itertools -import os -import select -import signal -import socket -import threading -import time - -# Equivalent to the C function raise(), which Python doesn't wrap -if os.name == "nt": - import cffi - - _ffi = cffi.FFI() - _ffi.cdef("int raise(int);") - _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll") - signal_raise = getattr(_lib, "raise") -else: - - def signal_raise(signum): - # Use pthread_kill to make sure we're actually using the wakeup fd on - # Unix - signal.pthread_kill(threading.get_ident(), signum) - - -def raise_SIGINT_soon(): - time.sleep(1) - signal_raise(signal.SIGINT) - # Sending 2 signals becomes reliable, as we'd expect (because we need - # set-flags -> write-to-fd, and doing it twice does - # write-to-fd -> set-flags -> write-to-fd -> set-flags) - # signal_raise(signal.SIGINT) - - -def drain(sock): - total = 0 - try: - while True: - total += len(sock.recv(1024)) - except BlockingIOError: - pass - return total - - -def main(): - writer, reader = socket.socketpair() - writer.setblocking(False) - reader.setblocking(False) - - signal.set_wakeup_fd(writer.fileno()) - - # Keep trying until we lose the race... - for attempt in itertools.count(): - print(f"Attempt {attempt}: start") - - # Make sure the socket is empty - drained = drain(reader) - if drained: - print(f"Attempt {attempt}: ({drained} residual bytes discarded)") - - # Arrange for SIGINT to be delivered 1 second from now - thread = threading.Thread(target=raise_SIGINT_soon) - thread.start() - - # Fake an IO loop that's trying to sleep for 10 seconds (but will - # hopefully get interrupted after just 1 second) - start = time.perf_counter() - target = start + 10 - try: - select_calls = 0 - drained = 0 - while True: - now = time.perf_counter() - if now > target: - break - select_calls += 1 - r, _, _ = select.select([reader], [], [], target - now) - if r: - # In theory we should loop to fully drain the socket but - # honestly there's 1 byte in there at most and it'll be - # fine. - drained += drain(reader) - except KeyboardInterrupt: - pass - else: - print(f"Attempt {attempt}: no KeyboardInterrupt?!") - - # We expect a successful run to take 1 second, and a failed run to - # take 10 seconds, so 2 seconds is a reasonable cutoff to distinguish - # them. - duration = time.perf_counter() - start - if duration < 2: - print( - f"Attempt {attempt}: OK, trying again " - f"(select_calls = {select_calls}, drained = {drained})" - ) - else: - print(f"Attempt {attempt}: FAILED, took {duration} seconds") - print(f"select_calls = {select_calls}, drained = {drained}") - break - - thread.join() - - -if __name__ == "__main__": - main() diff --git a/notes-to-self/win-waitable-timer.py b/notes-to-self/win-waitable-timer.py deleted file mode 100644 index 43e6acd34a..0000000000 --- a/notes-to-self/win-waitable-timer.py +++ /dev/null @@ -1,201 +0,0 @@ -# Sandbox for exploring the Windows "waitable timer" API. -# Cf https://github.com/python-trio/trio/issues/173 -# -# Observations: -# - if you set a timer in the far future, then block in -# WaitForMultipleObjects, then set the computer's clock forward by a few -# years (past the target sleep time), then the timer immediately wakes up -# (which is good!) -# - if you set a timer in the past, then it wakes up immediately - -# Random thoughts: -# - top-level API sleep_until_datetime -# - portable manages the heap of outstanding sleeps, runs a system task to -# wait for the next one, wakes up tasks when their deadline arrives, etc. -# - non-portable code: async def sleep_until_datetime_raw, which simply blocks -# until the given time using system-specific methods. Can assume that there -# is only one call to this method at a time. -# Actually, this should be a method, so it can hold persistent state (e.g. -# timerfd). -# Can assume that the datetime passed in has tzinfo=timezone.utc -# Need a way to override this object for testing. -# -# should we expose wake-system-on-alarm functionality? windows and linux both -# make this fairly straightforward, but you obviously need to use a separate -# time source - -import contextlib -from datetime import datetime, timedelta, timezone - -import cffi -import trio -from trio._core._windows_cffi import ffi, kernel32, raise_winerror - -with contextlib.suppress(cffi.CDefError): - ffi.cdef( - """ -typedef struct _PROCESS_LEAP_SECOND_INFO { - ULONG Flags; - ULONG Reserved; -} PROCESS_LEAP_SECOND_INFO, *PPROCESS_LEAP_SECOND_INFO; - -typedef struct _SYSTEMTIME { - WORD wYear; - WORD wMonth; - WORD wDayOfWeek; - WORD wDay; - WORD wHour; - WORD wMinute; - WORD wSecond; - WORD wMilliseconds; -} SYSTEMTIME, *PSYSTEMTIME, *LPSYSTEMTIME; -""" - ) - -ffi.cdef( - """ -typedef LARGE_INTEGER FILETIME; -typedef FILETIME* LPFILETIME; - -HANDLE CreateWaitableTimerW( - LPSECURITY_ATTRIBUTES lpTimerAttributes, - BOOL bManualReset, - LPCWSTR lpTimerName -); - -BOOL SetWaitableTimer( - HANDLE hTimer, - const LPFILETIME lpDueTime, - LONG lPeriod, - void* pfnCompletionRoutine, - LPVOID lpArgToCompletionRoutine, - BOOL fResume -); - -BOOL SetProcessInformation( - HANDLE hProcess, - /* Really an enum, PROCESS_INFORMATION_CLASS */ - int32_t ProcessInformationClass, - LPVOID ProcessInformation, - DWORD ProcessInformationSize -); - -void GetSystemTimeAsFileTime( - LPFILETIME lpSystemTimeAsFileTime -); - -BOOL SystemTimeToFileTime( - const SYSTEMTIME *lpSystemTime, - LPFILETIME lpFileTime -); -""", - override=True, -) - -ProcessLeapSecondInfo = 8 -PROCESS_LEAP_SECOND_INFO_FLAG_ENABLE_SIXTY_SECOND = 1 - - -def set_leap_seconds_enabled(enabled): - plsi = ffi.new("PROCESS_LEAP_SECOND_INFO*") - if enabled: - plsi.Flags = PROCESS_LEAP_SECOND_INFO_FLAG_ENABLE_SIXTY_SECOND - else: - plsi.Flags = 0 - plsi.Reserved = 0 - if not kernel32.SetProcessInformation( - ffi.cast("HANDLE", -1), # current process - ProcessLeapSecondInfo, - plsi, - ffi.sizeof("PROCESS_LEAP_SECOND_INFO"), - ): - raise_winerror() - - -def now_as_filetime(): - ft = ffi.new("LARGE_INTEGER*") - kernel32.GetSystemTimeAsFileTime(ft) - return ft[0] - - -# "FILETIME" is a specific Windows time representation, that I guess was used -# for files originally but now gets used in all kinds of non-file-related -# places. Essentially: integer count of "ticks" since an epoch in 1601, where -# each tick is 100 nanoseconds, in UTC but pretending that leap seconds don't -# exist. (Fortunately, the Python datetime module also pretends that -# leapseconds don't exist, so we can use datetime arithmetic to compute -# FILETIME values.) -# -# https://docs.microsoft.com/en-us/windows/win32/sysinfo/file-times -# -# This page has FILETIME converters and can be useful for debugging: -# -# https://www.epochconverter.com/ldap -# -FILETIME_TICKS_PER_SECOND = 10**7 -FILETIME_EPOCH = datetime.strptime("1601-01-01 00:00:00 Z", "%Y-%m-%d %H:%M:%S %z") -# XXX THE ABOVE IS WRONG: -# -# https://techcommunity.microsoft.com/t5/networking-blog/leap-seconds-for-the-appdev-what-you-should-know/ba-p/339813# -# -# Sometimes Windows FILETIME does include leap seconds! It depends on Windows -# version, process-global state, environment state, registry settings, and who -# knows what else! -# -# So actually the only correct way to convert a YMDhms-style representation of -# a time into a FILETIME is to use SystemTimeToFileTime -# -# ...also I can't even run this test on my VM, because it's running an ancient -# version of Win10 that doesn't have leap second support. Also also, Windows -# only tracks leap seconds since they added leap second support, and there -# haven't been any, so right now things work correctly either way. -# -# It is possible to insert some fake leap seconds for testing, if you want. - - -def py_datetime_to_win_filetime(dt): - # We'll want to call this on every datetime as it comes in - # dt = dt.astimezone(timezone.utc) - assert dt.tzinfo is timezone.utc - return round((dt - FILETIME_EPOCH).total_seconds() * FILETIME_TICKS_PER_SECOND) - - -async def main(): - h = kernel32.CreateWaitableTimerW(ffi.NULL, True, ffi.NULL) - if not h: - raise_winerror() - print(h) - - SECONDS = 2 - - wakeup = datetime.now(timezone.utc) + timedelta(seconds=SECONDS) - wakeup_filetime = py_datetime_to_win_filetime(wakeup) - wakeup_cffi = ffi.new("LARGE_INTEGER *") - wakeup_cffi[0] = wakeup_filetime - - print(wakeup_filetime, wakeup_cffi) - - print(f"Sleeping for {SECONDS} seconds (until {wakeup})") - - if not kernel32.SetWaitableTimer( - h, - wakeup_cffi, - 0, - ffi.NULL, - ffi.NULL, - False, - ): - raise_winerror() - - await trio.hazmat.WaitForSingleObject(h) - - print(f"Current FILETIME: {now_as_filetime()}") - set_leap_seconds_enabled(False) - print(f"Current FILETIME: {now_as_filetime()}") - set_leap_seconds_enabled(True) - print(f"Current FILETIME: {now_as_filetime()}") - set_leap_seconds_enabled(False) - print(f"Current FILETIME: {now_as_filetime()}") - - -trio.run(main) diff --git a/notes-to-self/windows-vm-notes.txt b/notes-to-self/windows-vm-notes.txt deleted file mode 100644 index 804069f108..0000000000 --- a/notes-to-self/windows-vm-notes.txt +++ /dev/null @@ -1,16 +0,0 @@ -Have a VM in virtualbox - -activate winenv here, or use py, py -m pip, etc.; regular python is -not in the path - -virtualbox is set to map my home dir to \\vboxsrv\njs, which can be -mapped to a drive with: net use x: \\vboxsrv\njs - -if switching back and forth between windows and linux in the same -directory and using the same version of python, .pyc files are a problem. - find -name __pycache__ | xargs rm -rf -export PYTHONDONTWRITEBYTECODE=1 - -if things freeze, control-C doesn't seem reliable... possibly this is -a bug in my code :-(. but can get to task manager via vbox menu Input -Keyboard -> Insert ctrl-alt-del. diff --git a/pyproject.toml b/pyproject.toml index 0e26fea83a..bbc865ac6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -35,12 +34,12 @@ classifiers = [ "Topic :: System :: Networking", "Typing :: Typed", ] -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [ # attrs 19.2.0 adds `eq` option to decorators # attrs 20.1.0 adds @frozen # attrs 21.1.0 adds a dataclass transform for type-checkers - # attrs 21.3.0 adds `import addrs` + # attrs 21.3.0 adds `import attrs` "attrs >= 23.2.0", "sortedcontainers", "idna", @@ -75,7 +74,6 @@ include-package-data = true version = {attr = "trio._version.__version__"} [tool.black] -target-version = ['py38'] force-exclude = ''' ( ^/docs/source/reference-.* @@ -92,7 +90,7 @@ fix = true # The directories to consider when resolving first vs. third-party imports. # Does not control what files to include/exclude! -src = ["src/trio", "notes-to-self"] +src = ["src/trio"] include = ["*.py", "*.pyi", "**/pyproject.toml"] @@ -102,36 +100,49 @@ extend-exclude = [ ] [tool.ruff.lint] +preview = true allowed-confusables = ["โ€“"] select = [ "A", # flake8-builtins + "ANN", # flake8-annotations "ASYNC", # flake8-async "B", # flake8-bugbear "C4", # flake8-comprehensions + "COM", # flake8-commas "E", # Error + "EXE", # flake8-executable "F", # pyflakes "FA", # flake8-future-annotations + "FLY", # flynt + "FURB", # refurb "I", # isort + "ICN", # flake8-import-conventions "PERF", # Perflint + "PIE", # flake8-pie "PT", # flake8-pytest-style "PYI", # flake8-pyi + "Q", # flake8-quotes "RUF", # Ruff-specific rules "SIM", # flake8-simplify - "TCH", # flake8-type-checking + "TC", # flake8-type-checking "UP", # pyupgrade "W", # Warning "YTT", # flake8-2020 ] extend-ignore = [ - 'A002', # builtin-argument-shadowing - 'E402', # module-import-not-at-top-of-file (usually OS-specific) - 'E501', # line-too-long - 'F403', # undefined-local-with-import-star - 'F405', # undefined-local-with-import-star-usage - 'PERF203', # try-except-in-loop (not always possible to refactor) - 'PT012', # multiple statements in pytest.raises block - 'SIM117', # multiple-with-statements (messes up lots of context-based stuff and looks bad) + "A002", # builtin-argument-shadowing + "ANN401", # any-type (mypy's `disallow_any_explicit` is better) + "E402", # module-import-not-at-top-of-file (usually OS-specific) + "E501", # line-too-long + "F403", # undefined-local-with-import-star + "F405", # undefined-local-with-import-star-usage + "PERF203", # try-except-in-loop (not always possible to refactor) + "PT012", # multiple statements in pytest.raises block + "SIM117", # multiple-with-statements (messes up lots of context-based stuff and looks bad) + + # conflicts with formatter (ruff recommends these be disabled) + "COM812", ] [tool.ruff.lint.per-file-ignores] @@ -140,10 +151,20 @@ extend-ignore = [ # to export for public use. 'src/trio/__init__.py' = ['F401'] 'src/trio/_core/__init__.py' = ['F401'] -'src/trio/abc.py' = ['F401'] +'src/trio/abc.py' = ['F401', 'A005'] 'src/trio/lowlevel.py' = ['F401'] -'src/trio/socket.py' = ['F401'] +'src/trio/socket.py' = ['F401', 'A005'] 'src/trio/testing/__init__.py' = ['F401'] +# RUF029 is ignoring tests that are marked as async functions but +# do not use an await in their function bodies. There are several +# places where internal trio synchronous code relies on being +# called from an async function, where current task is set up. +'src/trio/_tests/*.py' = ['RUF029'] +'src/trio/_core/_tests/*.py' = ['RUF029'] +# A005 is ignoring modules that shadow stdlib modules. +'src/trio/_abc.py' = ['A005'] +'src/trio/_socket.py' = ['A005'] +'src/trio/_ssl.py' = ['A005'] [tool.ruff.lint.isort] combine-as-imports = true @@ -152,7 +173,7 @@ combine-as-imports = true fixture-parentheses = false [tool.mypy] -python_version = "3.8" +python_version = "3.9" files = ["src/trio/", "docs/source/*.py"] # Be flexible about dependencies that don't have stubs yet (like pytest) @@ -167,6 +188,7 @@ warn_return_any = true # Avoid subtle backsliding disallow_any_decorated = true +disallow_any_explicit = true disallow_any_generics = true disallow_any_unimported = true disallow_incomplete_defs = true @@ -177,12 +199,12 @@ disallow_untyped_defs = true check_untyped_defs = true [tool.pyright] -pythonVersion = "3.8" +pythonVersion = "3.9" reportUnnecessaryTypeIgnoreComment = true typeCheckingMode = "strict" [tool.pytest.ini_options] -addopts = ["--strict-markers", "--strict-config", "-p trio._tests.pytest_plugin"] +addopts = ["--strict-markers", "--strict-config", "-p trio._tests.pytest_plugin", "--import-mode=importlib"] faulthandler_timeout = 60 filterwarnings = [ "error", @@ -249,6 +271,20 @@ directory = "misc" name = "Miscellaneous internal changes" showcontent = true +[tool.coverage.html] +show_contexts = true +skip_covered = false + +[tool.coverage.paths] +_site-packages-to-src-mapping = [ + "src", + "*/src", + '*\src', + "*/lib/pypy*/site-packages", + "*/lib/python*/site-packages", + '*\Lib\site-packages', +] + [tool.coverage.run] branch = true source_pkgs = ["trio"] @@ -263,19 +299,25 @@ omit = [ # The test suite spawns subprocesses to test some stuff, so make sure # this doesn't corrupt the coverage files parallel = true +plugins = [] +relative_files = true +source = ["."] [tool.coverage.report] precision = 1 skip_covered = true -exclude_lines = [ - "pragma: no cover", - "abc.abstractmethod", - "if TYPE_CHECKING.*:", - "if _t.TYPE_CHECKING:", - "if t.TYPE_CHECKING:", - "@overload", - 'class .*\bProtocol\b.*\):', - "raise NotImplementedError", +skip_empty = true +show_missing = true +exclude_also = [ + '^\s*@pytest\.mark\.xfail', + "abc.abstractmethod", + "if TYPE_CHECKING.*:", + "if _t.TYPE_CHECKING:", + "if t.TYPE_CHECKING:", + "@overload", + 'class .*\bProtocol\b.*\):', + "raise NotImplementedError", + 'TODO: test this line' ] partial_branches = [ "pragma: no branch", @@ -285,4 +327,5 @@ partial_branches = [ "if .* or not TYPE_CHECKING:", "if .* or not _t.TYPE_CHECKING:", "if .* or not t.TYPE_CHECKING:", + 'TODO: test this branch', ] diff --git a/src/trio/__init__.py b/src/trio/__init__.py index d2151677b1..34fda84525 100644 --- a/src/trio/__init__.py +++ b/src/trio/__init__.py @@ -25,6 +25,7 @@ # Submodules imported by default from . import abc, from_thread, lowlevel, socket, to_thread from ._channel import ( + MemoryChannelStatistics as MemoryChannelStatistics, MemoryReceiveChannel as MemoryReceiveChannel, MemorySendChannel as MemorySendChannel, open_memory_channel as open_memory_channel, diff --git a/src/trio/_abc.py b/src/trio/_abc.py index 20f1614cc6..306ee227fc 100644 --- a/src/trio/_abc.py +++ b/src/trio/_abc.py @@ -198,7 +198,9 @@ async def getaddrinfo( @abstractmethod async def getnameinfo( - self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + self, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, ) -> tuple[str, str]: """A custom implementation of :func:`~trio.socket.getnameinfo`. @@ -691,6 +693,14 @@ async def __anext__(self) -> ReceiveType: raise StopAsyncIteration from None +# these are necessary for Sphinx's :show-inheritance: with type args. +# (this should be removed if possible) +# see: https://github.com/python/cpython/issues/123250 +SendChannel.__module__ = SendChannel.__module__.replace("_abc", "abc") +ReceiveChannel.__module__ = ReceiveChannel.__module__.replace("_abc", "abc") +Listener.__module__ = Listener.__module__.replace("_abc", "abc") + + class Channel(SendChannel[T], ReceiveChannel[T]): """A standard interface for interacting with bidirectional channels. @@ -700,3 +710,7 @@ class Channel(SendChannel[T], ReceiveChannel[T]): """ __slots__ = () + + +# see above +Channel.__module__ = Channel.__module__.replace("_abc", "abc") diff --git a/src/trio/_channel.py b/src/trio/_channel.py index f5ed4004d7..6410d9120c 100644 --- a/src/trio/_channel.py +++ b/src/trio/_channel.py @@ -5,7 +5,6 @@ from typing import ( TYPE_CHECKING, Generic, - Tuple, # only needed for typechecking on <3.9 ) import attrs @@ -93,14 +92,14 @@ def _open_memory_channel( # it could replace the normal function header if TYPE_CHECKING: # written as a class so you can say open_memory_channel[int](5) - # Need to use Tuple instead of tuple due to CI check running on 3.8 - class open_memory_channel(Tuple["MemorySendChannel[T]", "MemoryReceiveChannel[T]"]): + class open_memory_channel(tuple["MemorySendChannel[T]", "MemoryReceiveChannel[T]"]): def __new__( # type: ignore[misc] # "must return a subtype" - cls, max_buffer_size: int | float # noqa: PYI041 + cls, + max_buffer_size: int | float, # noqa: PYI041 ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: return _open_memory_channel(max_buffer_size) - def __init__(self, max_buffer_size: int | float): # noqa: PYI041 + def __init__(self, max_buffer_size: int | float) -> None: # noqa: PYI041 ... else: @@ -110,7 +109,7 @@ def __init__(self, max_buffer_size: int | float): # noqa: PYI041 @attrs.frozen -class MemoryChannelStats: +class MemoryChannelStatistics: current_buffer_used: int max_buffer_size: int | float open_send_channels: int @@ -131,8 +130,8 @@ class MemoryChannelState(Generic[T]): # {task: None} receive_tasks: OrderedDict[Task, None] = attrs.Factory(OrderedDict) - def statistics(self) -> MemoryChannelStats: - return MemoryChannelStats( + def statistics(self) -> MemoryChannelStatistics: + return MemoryChannelStatistics( current_buffer_used=len(self.data), max_buffer_size=self.max_buffer_size, open_send_channels=self.open_send_channels, @@ -158,7 +157,9 @@ def __attrs_post_init__(self) -> None: def __repr__(self) -> str: return f"" - def statistics(self) -> MemoryChannelStats: + def statistics(self) -> MemoryChannelStatistics: + """Returns a `MemoryChannelStatistics` for the memory channel this is + associated with.""" # XX should we also report statistics specific to this object? return self._state.statistics() @@ -281,6 +282,9 @@ def close(self) -> None: @enable_ki_protection async def aclose(self) -> None: + """Close this send channel object asynchronously. + + See `MemorySendChannel.close`.""" self.close() await trio.lowlevel.checkpoint() @@ -295,7 +299,9 @@ class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstr def __attrs_post_init__(self) -> None: self._state.open_receive_channels += 1 - def statistics(self) -> MemoryChannelStats: + def statistics(self) -> MemoryChannelStatistics: + """Returns a `MemoryChannelStatistics` for the memory channel this is + associated with.""" return self._state.statistics() def __repr__(self) -> str: @@ -429,5 +435,8 @@ def close(self) -> None: @enable_ki_protection async def aclose(self) -> None: + """Close this receive channel object asynchronously. + + See `MemoryReceiveChannel.close`.""" self.close() await trio.lowlevel.checkpoint() diff --git a/src/trio/_core/__init__.py b/src/trio/_core/__init__.py index 3cec8a52e0..4aa096fd0b 100644 --- a/src/trio/_core/__init__.py +++ b/src/trio/_core/__init__.py @@ -20,7 +20,12 @@ from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection from ._local import RunVar, RunVarToken from ._mock_clock import MockClock -from ._parking_lot import ParkingLot, ParkingLotStatistics +from ._parking_lot import ( + ParkingLot, + ParkingLotStatistics, + add_parking_lot_breaker, + remove_parking_lot_breaker, +) # Imports that always exist from ._run import ( diff --git a/src/trio/_core/_asyncgens.py b/src/trio/_core/_asyncgens.py index 1a622dadfc..b3b6895753 100644 --- a/src/trio/_core/_asyncgens.py +++ b/src/trio/_core/_asyncgens.py @@ -4,7 +4,7 @@ import sys import warnings import weakref -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, NoReturn, TypeVar import attrs @@ -16,15 +16,31 @@ ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors") if TYPE_CHECKING: + from collections.abc import Callable from types import AsyncGeneratorType - from typing import Set + + from typing_extensions import ParamSpec + + _P = ParamSpec("_P") _WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]] - _ASYNC_GEN_SET = Set[AsyncGeneratorType[object, NoReturn]] + _ASYNC_GEN_SET = set[AsyncGeneratorType[object, NoReturn]] else: _WEAK_ASYNC_GEN_SET = weakref.WeakSet _ASYNC_GEN_SET = set +_R = TypeVar("_R") + + +@_core.disable_ki_protection +def _call_without_ki_protection( + f: Callable[_P, _R], + /, + *args: _P.args, + **kwargs: _P.kwargs, +) -> _R: + return f(*args, **kwargs) + @attrs.define(eq=False) class AsyncGenerators: @@ -36,6 +52,11 @@ class AsyncGenerators: # regular set so we don't have to deal with GC firing at # unexpected times. alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET) + # The ids of foreign async generators are added to this set when first + # iterated. Usually it is not safe to refer to ids like this, but because + # we're using a finalizer we can ensure ids in this set do not outlive + # their async generator. + foreign: set[int] = attrs.Factory(set) # This collects async generators that get garbage collected during # the one-tick window between the system nursery closing and the @@ -52,15 +73,16 @@ def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None: # An async generator first iterated outside of a Trio # task doesn't belong to Trio. Probably we're in guest # mode and the async generator belongs to our host. - # The locals dictionary is the only good place to + # A strong set of ids is one of the only good places to # remember this fact, at least until - # https://bugs.python.org/issue40916 is implemented. - agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True + # https://github.com/python/cpython/issues/85093 is implemented. + self.foreign.add(id(agen)) if self.prev_hooks.firstiter is not None: self.prev_hooks.firstiter(agen) def finalize_in_trio_context( - agen: AsyncGeneratorType[object, NoReturn], agen_name: str + agen: AsyncGeneratorType[object, NoReturn], + agen_name: str, ) -> None: try: runner.spawn_system_task( @@ -76,16 +98,21 @@ def finalize_in_trio_context( # have hit it. self.trailing_needs_finalize.add(agen) + @_core.enable_ki_protection def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: - agen_name = name_asyncgen(agen) try: - is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen") - except AttributeError: # pragma: no cover + self.foreign.remove(id(agen)) + except KeyError: is_ours = True + else: + is_ours = False + agen_name = name_asyncgen(agen) if is_ours: runner.entry_queue.run_sync_soon( - finalize_in_trio_context, agen, agen_name + finalize_in_trio_context, + agen, + agen_name, ) # Do this last, because it might raise an exception @@ -103,8 +130,9 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: ) else: # Not ours -> forward to the host loop's async generator finalizer - if self.prev_hooks.finalizer is not None: - self.prev_hooks.finalizer(agen) + finalizer = self.prev_hooks.finalizer + if finalizer is not None: + _call_without_ki_protection(finalizer, agen) else: # Host has no finalizer. Reimplement the default # Python behavior with no hooks installed: throw in @@ -114,7 +142,7 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: try: # If the next thing is a yield, this will raise RuntimeError # which we allow to propagate - closer.send(None) + _call_without_ki_protection(closer.send, None) except StopIteration: pass else: @@ -123,7 +151,7 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: raise RuntimeError( f"Non-Trio async generator {agen_name!r} awaited something " "during finalization; install a finalization hook to " - "support this, or wrap it in 'async with aclosing(...):'" + "support this, or wrap it in 'async with aclosing(...):'", ) self.prev_hooks = sys.get_asyncgen_hooks() @@ -146,7 +174,7 @@ async def finalize_remaining(self, runner: _run.Runner) -> None: # them was an asyncgen finalizer that snuck in under the wire. runner.entry_queue.run_sync_soon(runner.reschedule, runner.init_task) await _core.wait_task_rescheduled( - lambda _: _core.Abort.FAILED # pragma: no cover + lambda _: _core.Abort.FAILED, # pragma: no cover ) self.alive.update(self.trailing_needs_finalize) self.trailing_needs_finalize.clear() @@ -193,7 +221,9 @@ def close(self) -> None: sys.set_asyncgen_hooks(*self.prev_hooks) async def _finalize_one( - self, agen: AsyncGeneratorType[object, NoReturn], name: object + self, + agen: AsyncGeneratorType[object, NoReturn], + name: object, ) -> None: try: # This shield ensures that finalize_asyncgen never exits diff --git a/src/trio/_core/_concat_tb.py b/src/trio/_core/_concat_tb.py index 497d37f8ad..a1469618e1 100644 --- a/src/trio/_core/_concat_tb.py +++ b/src/trio/_core/_concat_tb.py @@ -1,7 +1,9 @@ from __future__ import annotations -from types import TracebackType -from typing import Any, ClassVar, cast +from typing import TYPE_CHECKING, ClassVar, cast + +if TYPE_CHECKING: + from types import TracebackType ################################################################ # concat_tb @@ -86,7 +88,9 @@ def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackT def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: # tputil.ProxyOperation is PyPy-only, and there's no way to specify # cpython/pypy in current type checkers. - def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported] + def controller( # type: ignore[no-any-unimported] + operation: tputil.ProxyOperation, + ) -> TracebackType | None: # Rationale for pragma: I looked fairly carefully and tried a few # things, and AFAICT it's not actually possible to get any # 'opname' that isn't __getattr__ or __getattribute__. So there's @@ -99,12 +103,14 @@ def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[ "__getattr__", } and operation.args[0] == "tb_next" - ): # pragma: no cover + ) or TYPE_CHECKING: # pragma: no cover return tb_next - return operation.delegate() # Delegate is reverting to original behaviour + # Delegate is reverting to original behaviour + return operation.delegate() # type: ignore[no-any-return] return cast( - TracebackType, tputil.make_proxy(controller, type(base_tb), base_tb) + "TracebackType", + tputil.make_proxy(controller, type(base_tb), base_tb), ) # Returns proxy to traceback @@ -112,7 +118,8 @@ def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[ # `strict_exception_groups=False`. Once that is retired this function and its helper can # be removed as well. def concat_tb( - head: TracebackType | None, tail: TracebackType | None + head: TracebackType | None, + tail: TracebackType | None, ) -> TracebackType | None: # We have to use an iterative algorithm here, because in the worst case # this might be a RecursionError stack that is by definition too deep to diff --git a/src/trio/_core/_entry_queue.py b/src/trio/_core/_entry_queue.py index 7f1eea8e29..0691de3517 100644 --- a/src/trio/_core/_entry_queue.py +++ b/src/trio/_core/_entry_queue.py @@ -2,7 +2,8 @@ import threading from collections import deque -from typing import TYPE_CHECKING, Callable, NoReturn, Tuple +from collections.abc import Callable +from typing import TYPE_CHECKING, NoReturn import attrs @@ -15,8 +16,9 @@ PosArgsT = TypeVarTuple("PosArgsT") -Function = Callable[..., object] -Job = Tuple[Function, Tuple[object, ...]] +# Explicit "Any" is not allowed +Function = Callable[..., object] # type: ignore[misc] +Job = tuple[Function, tuple[object, ...]] @attrs.define @@ -64,7 +66,9 @@ def run_cb(job: Job) -> None: sync_fn(*args) except BaseException as exc: - async def kill_everything(exc: BaseException) -> NoReturn: + async def kill_everything( # noqa: RUF029 # await not used + exc: BaseException, + ) -> NoReturn: raise exc try: @@ -77,7 +81,7 @@ async def kill_everything(exc: BaseException) -> NoReturn: parent_nursery = _core.current_task().parent_nursery if parent_nursery is None: raise AssertionError( - "Internal error: `parent_nursery` should never be `None`" + "Internal error: `parent_nursery` should never be `None`", ) from exc # pragma: no cover parent_nursery.start_soon(kill_everything, exc) @@ -139,14 +143,14 @@ def run_sync_soon( # wakeup call might trigger an OSError b/c the IO manager has # already been shut down. if idempotent: - self.idempotent_queue[(sync_fn, args)] = None + self.idempotent_queue[sync_fn, args] = None else: self.queue.append((sync_fn, args)) self.wakeup.wakeup_thread_and_signal_safe() @final -@attrs.define(eq=False, hash=False) +@attrs.define(eq=False) class TrioToken(metaclass=NoPublicConstructor): """An opaque object representing a single call to :func:`trio.run`. diff --git a/src/trio/_core/_generated_instrumentation.py b/src/trio/_core/_generated_instrumentation.py index 568b76dffa..d03ef9db7d 100644 --- a/src/trio/_core/_generated_instrumentation.py +++ b/src/trio/_core/_generated_instrumentation.py @@ -3,10 +3,9 @@ # ************************************************************* from __future__ import annotations -import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -15,6 +14,7 @@ __all__ = ["add_instrument", "remove_instrument"] +@enable_ki_protection def add_instrument(instrument: Instrument) -> None: """Start instrumenting the current run loop with the given instrument. @@ -24,13 +24,13 @@ def add_instrument(instrument: Instrument) -> None: If ``instrument`` is already active, does nothing. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def remove_instrument(instrument: Instrument) -> None: """Stop instrumenting the current run loop with the given instrument. @@ -44,7 +44,6 @@ def remove_instrument(instrument: Instrument) -> None: deactivated. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument) except AttributeError: diff --git a/src/trio/_core/_generated_io_epoll.py b/src/trio/_core/_generated_io_epoll.py index 9f9ad59725..41cbb40650 100644 --- a/src/trio/_core/_generated_io_epoll.py +++ b/src/trio/_core/_generated_io_epoll.py @@ -6,7 +6,7 @@ import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -18,6 +18,7 @@ __all__ = ["notify_closing", "wait_readable", "wait_writable"] +@enable_ki_protection async def wait_readable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is readable. @@ -40,13 +41,13 @@ async def wait_readable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_writable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is writable. @@ -59,13 +60,13 @@ async def wait_writable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def notify_closing(fd: int | _HasFileNo) -> None: """Notify waiters of the given object that it will be closed. @@ -91,7 +92,6 @@ def notify_closing(fd: int | _HasFileNo) -> None: step, so other tasks won't be able to tell what order they happened in anyway. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: diff --git a/src/trio/_core/_generated_io_kqueue.py b/src/trio/_core/_generated_io_kqueue.py index ff8349dffd..3618f48d73 100644 --- a/src/trio/_core/_generated_io_kqueue.py +++ b/src/trio/_core/_generated_io_kqueue.py @@ -4,13 +4,15 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Callable, ContextManager +from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: import select + from collections.abc import Callable + from contextlib import AbstractContextManager from .._channel import MemoryReceiveChannel from .._file_io import _HasFileNo @@ -29,32 +31,33 @@ ] +@enable_ki_protection def current_kqueue() -> select.kqueue: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def monitor_kevent( ident: int, filter: int -) -> ContextManager[MemoryReceiveChannel[select.kevent]]: +) -> AbstractContextManager[MemoryReceiveChannel[select.kevent]]: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_kevent( ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort] ) -> Abort: @@ -62,7 +65,6 @@ async def wait_kevent( anything real. See `#26 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent( ident, filter, abort_func @@ -71,6 +73,7 @@ async def wait_kevent( raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_readable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is readable. @@ -93,13 +96,13 @@ async def wait_readable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_writable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is writable. @@ -112,13 +115,13 @@ async def wait_writable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def notify_closing(fd: int | _HasFileNo) -> None: """Notify waiters of the given object that it will be closed. @@ -144,7 +147,6 @@ def notify_closing(fd: int | _HasFileNo) -> None: step, so other tasks won't be able to tell what order they happened in anyway. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: diff --git a/src/trio/_core/_generated_io_windows.py b/src/trio/_core/_generated_io_windows.py index 5765a40798..b530fda2b6 100644 --- a/src/trio/_core/_generated_io_windows.py +++ b/src/trio/_core/_generated_io_windows.py @@ -4,12 +4,14 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, ContextManager +from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: + from contextlib import AbstractContextManager + from typing_extensions import Buffer from .._channel import MemoryReceiveChannel @@ -32,6 +34,7 @@ ] +@enable_ki_protection async def wait_readable(sock: _HasFileNo | int) -> None: """Block until the kernel reports that the given object is readable. @@ -54,13 +57,13 @@ async def wait_readable(sock: _HasFileNo | int) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_writable(sock: _HasFileNo | int) -> None: """Block until the kernel reports that the given object is writable. @@ -73,13 +76,13 @@ async def wait_writable(sock: _HasFileNo | int) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def notify_closing(handle: Handle | int | _HasFileNo) -> None: """Notify waiters of the given object that it will be closed. @@ -105,33 +108,32 @@ def notify_closing(handle: Handle | int | _HasFileNo) -> None: step, so other tasks won't be able to tell what order they happened in anyway. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def register_with_iocp(handle: int | CData) -> None: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> object: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped( handle_, lpOverlapped @@ -140,6 +142,7 @@ async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> ob raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def write_overlapped( handle: int | CData, data: Buffer, file_offset: int = 0 ) -> int: @@ -148,7 +151,6 @@ async def write_overlapped( `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped( handle, data, file_offset @@ -157,6 +159,7 @@ async def write_overlapped( raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def readinto_overlapped( handle: int | CData, buffer: Buffer, file_offset: int = 0 ) -> int: @@ -165,7 +168,6 @@ async def readinto_overlapped( `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped( handle, buffer, file_offset @@ -174,28 +176,28 @@ async def readinto_overlapped( raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_iocp() -> int: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def monitor_completion_key() -> ( - ContextManager[tuple[int, MemoryReceiveChannel[object]]] + AbstractContextManager[tuple[int, MemoryReceiveChannel[object]]] ): """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() except AttributeError: diff --git a/src/trio/_core/_generated_run.py b/src/trio/_core/_generated_run.py index ac3e0f39d6..db1454e6c7 100644 --- a/src/trio/_core/_generated_run.py +++ b/src/trio/_core/_generated_run.py @@ -3,10 +3,9 @@ # ************************************************************* from __future__ import annotations -import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT, RunStatistics, Task if TYPE_CHECKING: @@ -33,6 +32,7 @@ ] +@enable_ki_protection def current_statistics() -> RunStatistics: """Returns ``RunStatistics``, which contains run-loop-level debugging information. @@ -56,13 +56,13 @@ def current_statistics() -> RunStatistics: other attributes vary between backends. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_statistics() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_time() -> float: """Returns the current time according to Trio's internal clock. @@ -73,36 +73,36 @@ def current_time() -> float: RuntimeError: if not inside a call to :func:`trio.run`. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_time() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_clock() -> Clock: """Returns the current :class:`~trio.abc.Clock`.""" - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_clock() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_root_task() -> Task | None: """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_root_task() except AttributeError: raise RuntimeError("must be called from async context") from None -def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None: +@enable_ki_protection +def reschedule(task: Task, next_send: Outcome[object] = _NO_SEND) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -120,13 +120,13 @@ def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None: raise) from :func:`wait_task_rescheduled`. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def spawn_system_task( async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], *args: Unpack[PosArgT], @@ -184,7 +184,6 @@ def spawn_system_task( Task: the newly spawned task """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.spawn_system_task( async_fn, *args, name=name, context=context @@ -193,18 +192,19 @@ def spawn_system_task( raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_trio_token() -> TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_trio_token() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_all_tasks_blocked(cushion: float = 0.0) -> None: """Block until there are no runnable tasks. @@ -263,7 +263,6 @@ async def test_lock_fairness(): print("FAIL") """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion) except AttributeError: diff --git a/src/trio/_core/_instrumentation.py b/src/trio/_core/_instrumentation.py index c1063b0e3e..40bddd1a23 100644 --- a/src/trio/_core/_instrumentation.py +++ b/src/trio/_core/_instrumentation.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging import types -from typing import Any, Callable, Dict, Sequence, TypeVar +from collections.abc import Callable, Sequence +from typing import TypeVar from .._abc import Instrument @@ -8,16 +11,18 @@ INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument") -F = TypeVar("F", bound=Callable[..., Any]) +# Explicit "Any" is not allowed +F = TypeVar("F", bound=Callable[..., object]) # type: ignore[misc] # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn: F) -> F: +# Explicit "Any" is not allowed +def _public(fn: F) -> F: # type: ignore[misc] return fn -class Instruments(Dict[str, Dict[Instrument, None]]): +class Instruments(dict[str, dict[Instrument, None]]): """A collection of `trio.abc.Instrument` organized by hook. Instrumentation calls are rather expensive, and we don't want a @@ -29,7 +34,7 @@ class Instruments(Dict[str, Dict[Instrument, None]]): __slots__ = () - def __init__(self, incoming: Sequence[Instrument]): + def __init__(self, incoming: Sequence[Instrument]) -> None: self["_all"] = {} for instrument in incoming: self.add_instrument(instrument) @@ -86,7 +91,11 @@ def remove_instrument(self, instrument: Instrument) -> None: if not instruments: del self[hookname] - def call(self, hookname: str, *args: Any) -> None: + def call( + self, + hookname: str, + *args: object, + ) -> None: """Call hookname(*args) on each applicable instrument. You must first check whether there are any instruments installed for diff --git a/src/trio/_core/_io_epoll.py b/src/trio/_core/_io_epoll.py index 1f4ae49f7a..5e05f0813f 100644 --- a/src/trio/_core/_io_epoll.py +++ b/src/trio/_core/_io_epoll.py @@ -198,14 +198,14 @@ class _EpollStatistics: # wanted to about how epoll works. -@attrs.define(eq=False, hash=False) +@attrs.define(eq=False) class EpollIOManager: # Using lambda here because otherwise crash on import with gevent monkey patching # See https://github.com/python-trio/trio/issues/2848 _epoll: select.epoll = attrs.Factory(lambda: select.epoll()) # {fd: EpollWaiters} _registered: defaultdict[int, EpollWaiters] = attrs.Factory( - lambda: defaultdict(EpollWaiters) + lambda: defaultdict(EpollWaiters), ) _force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair) _force_wakeup_fd: int | None = None @@ -298,7 +298,7 @@ async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None: waiters = self._registered[fd] if getattr(waiters, attr_name) is not None: raise _core.BusyResourceError( - "another task is already reading / writing this fd" + "another task is already reading / writing this fd", ) setattr(waiters, attr_name, _core.current_task()) self._update_registrations(fd) diff --git a/src/trio/_core/_io_kqueue.py b/src/trio/_core/_io_kqueue.py index 07903cb886..2383ed64bd 100644 --- a/src/trio/_core/_io_kqueue.py +++ b/src/trio/_core/_io_kqueue.py @@ -5,7 +5,7 @@ import select import sys from contextlib import contextmanager -from typing import TYPE_CHECKING, Callable, Iterator, Literal +from typing import TYPE_CHECKING, Literal import attrs import outcome @@ -15,6 +15,8 @@ from ._wakeup_socketpair import WakeupSocketpair if TYPE_CHECKING: + from collections.abc import Callable, Iterator + from typing_extensions import TypeAlias from .._channel import MemoryReceiveChannel, MemorySendChannel @@ -44,7 +46,9 @@ class KqueueIOManager: def __attrs_post_init__(self) -> None: force_wakeup_event = select.kevent( - self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD + self._force_wakeup.wakeup_sock, + select.KQ_FILTER_READ, + select.KQ_EV_ADD, ) self._kqueue.control([force_wakeup_event], 0) self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() @@ -78,7 +82,7 @@ def get_events(self, timeout: float) -> EventResult: events += batch if len(batch) < max_events: break - else: + else: # TODO: test this line timeout = 0 # and loop back to the start return events @@ -90,12 +94,12 @@ def process_events(self, events: EventResult) -> None: self._force_wakeup.drain() continue receiver = self._registered[key] - if event.flags & select.KQ_EV_ONESHOT: + if event.flags & select.KQ_EV_ONESHOT: # TODO: test this branch del self._registered[key] if isinstance(receiver, _core.Task): _core.reschedule(receiver, outcome.Value(event)) else: - receiver.send_nowait(event) + receiver.send_nowait(event) # TODO: test this line # kevent registration is complicated -- e.g. aio submission can # implicitly perform a EV_ADD, and EVFILT_PROC with NOTE_TRACK will @@ -130,7 +134,7 @@ def monitor_kevent( key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( - "attempt to register multiple listeners for same ident/filter pair" + "attempt to register multiple listeners for same ident/filter pair", ) send, recv = open_memory_channel[select.kevent](math.inf) self._registered[key] = send @@ -154,13 +158,13 @@ async def wait_kevent( key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( - "attempt to register multiple listeners for same ident/filter pair" + "attempt to register multiple listeners for same ident/filter pair", ) self._registered[key] = _core.current_task() def abort(raise_cancel: RaiseCancelT) -> Abort: r = abort_func(raise_cancel) - if r is _core.Abort.SUCCEEDED: + if r is _core.Abort.SUCCEEDED: # TODO: test this branch del self._registered[key] return r diff --git a/src/trio/_core/_io_windows.py b/src/trio/_core/_io_windows.py index 3d6fbb14c6..4676e6c5cf 100644 --- a/src/trio/_core/_io_windows.py +++ b/src/trio/_core/_io_windows.py @@ -8,10 +8,8 @@ from contextlib import contextmanager from typing import ( TYPE_CHECKING, - Any, - Callable, - Iterator, Literal, + Protocol, TypeVar, cast, ) @@ -27,6 +25,7 @@ AFDPollFlags, CData, CompletionModes, + CType, ErrorCodes, FileFlags, Handle, @@ -42,6 +41,8 @@ ) if TYPE_CHECKING: + from collections.abc import Callable, Iterator + from typing_extensions import Buffer, TypeAlias from .._channel import MemoryReceiveChannel, MemorySendChannel @@ -153,7 +154,8 @@ # # Unfortunately, the Windows kernel seems to have bugs if you try to issue # multiple simultaneous IOCTL_AFD_POLL operations on the same socket (see -# notes-to-self/afd-lab.py). So if a user calls wait_readable and +# https://github.com/python-trio/trio/wiki/notes-to-self#afd-labpy). +# So if a user calls wait_readable and # wait_writable at the same time, we have to combine those into a single # IOCTL_AFD_POLL. This means we can't just use the wait_overlapped machinery. # Instead we have some dedicated code to handle these operations, and a @@ -250,13 +252,28 @@ class AFDWaiters: current_op: AFDPollOp | None = None +# Just used for internal type checking. +class _AFDHandle(Protocol): + Handle: Handle + Status: int + Events: int + + +# Just used for internal type checking. +class _AFDPollInfo(Protocol): + Timeout: int + NumberOfHandles: int + Exclusive: int + Handles: list[_AFDHandle] + + # We also need to bundle up all the info for a single op into a standalone # object, because we need to keep all these objects alive until the operation # finishes, even if we're throwing it away. @attrs.frozen(eq=False) class AFDPollOp: lpOverlapped: CData - poll_info: Any + poll_info: _AFDPollInfo waiters: AFDWaiters afd_group: AFDGroup @@ -304,7 +321,9 @@ def _check(success: T) -> T: def _get_underlying_socket( - sock: _HasFileNo | int | Handle, *, which: WSAIoctls = WSAIoctls.SIO_BASE_HANDLE + sock: _HasFileNo | int | Handle, + *, + which: WSAIoctls = WSAIoctls.SIO_BASE_HANDLE, ) -> Handle: if hasattr(sock, "fileno"): sock = sock.fileno() @@ -355,7 +374,8 @@ def _get_base_socket(sock: _HasFileNo | int | Handle) -> Handle: sock = sock.fileno() sock = _handle(sock) next_sock = _get_underlying_socket( - sock, which=WSAIoctls.SIO_BSP_HANDLE_POLL + sock, + which=WSAIoctls.SIO_BSP_HANDLE_POLL, ) if next_sock == sock: # If BSP_HANDLE_POLL returns the same socket we already had, @@ -367,7 +387,7 @@ def _get_base_socket(sock: _HasFileNo | int | Handle) -> Handle: "return a different socket. Please file a bug at " "https://github.com/python-trio/trio/issues/new, " "and include the output of running: " - "netsh winsock show catalog" + "netsh winsock show catalog", ) from ex # Otherwise we've gotten at least one layer deeper, so # loop back around to keep digging. @@ -422,7 +442,7 @@ def __init__(self) -> None: self._all_afd_handles: list[Handle] = [] self._iocp = _check( - kernel32.CreateIoCompletionPort(INVALID_HANDLE_VALUE, ffi.NULL, 0, 0) + kernel32.CreateIoCompletionPort(INVALID_HANDLE_VALUE, ffi.NULL, 0, 0), ) self._events = ffi.new("OVERLAPPED_ENTRY[]", MAX_EVENTS) @@ -455,7 +475,8 @@ def __init__(self) -> None: # LSPs can in theory override this, but we believe that it never # actually happens in the wild (except Komodia) select_handle = _get_underlying_socket( - s, which=WSAIoctls.SIO_BSP_HANDLE_SELECT + s, + which=WSAIoctls.SIO_BSP_HANDLE_SELECT, ) try: # LSPs shouldn't override this... @@ -473,7 +494,7 @@ def __init__(self) -> None: "Please file a bug at " "https://github.com/python-trio/trio/issues/new, " "and include the output of running: " - "netsh winsock show catalog" + "netsh winsock show catalog", ) def close(self) -> None: @@ -509,8 +530,11 @@ def force_wakeup(self) -> None: assert self._iocp is not None _check( kernel32.PostQueuedCompletionStatus( - self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL - ) + self._iocp, + 0, + CKeys.FORCE_WAKEUP, + ffi.NULL, + ), ) def get_events(self, timeout: float) -> EventResult: @@ -522,8 +546,13 @@ def get_events(self, timeout: float) -> EventResult: assert self._iocp is not None _check( kernel32.GetQueuedCompletionStatusEx( - self._iocp, self._events, MAX_EVENTS, received, milliseconds, 0 - ) + self._iocp, + self._events, + MAX_EVENTS, + received, + milliseconds, + 0, + ), ) except OSError as exc: if exc.winerror != ErrorCodes.WAIT_TIMEOUT: # pragma: no cover @@ -564,7 +593,8 @@ def process_events(self, received: EventResult) -> None: overlapped = entry.lpOverlapped transferred = entry.dwNumberOfBytesTransferred info = CompletionKeyEventInfo( - lpOverlapped=overlapped, dwNumberOfBytesTransferred=transferred + lpOverlapped=overlapped, + dwNumberOfBytesTransferred=transferred, ) _core.reschedule(waiter, Value(info)) elif entry.lpCompletionKey == CKeys.LATE_CANCEL: @@ -587,7 +617,7 @@ def process_events(self, received: EventResult) -> None: exc = _core.TrioInternalError( f"Failed to cancel overlapped I/O in {waiter.name} and didn't " "receive the completion either. Did you forget to " - "call register_with_iocp()?" + "call register_with_iocp()?", ) # Raising this out of handle_io ensures that # the user will see our message even if some @@ -609,7 +639,8 @@ def process_events(self, received: EventResult) -> None: overlapped = int(ffi.cast("uintptr_t", entry.lpOverlapped)) transferred = entry.dwNumberOfBytesTransferred info = CompletionKeyEventInfo( - lpOverlapped=overlapped, dwNumberOfBytesTransferred=transferred + lpOverlapped=overlapped, + dwNumberOfBytesTransferred=transferred, ) queue.send_nowait(info) @@ -623,8 +654,9 @@ def _register_with_iocp(self, handle_: int | CData, completion_key: int) -> None # Ref: http://www.lenholgate.com/blog/2009/09/interesting-blog-posts-on-high-performance-servers.html _check( kernel32.SetFileCompletionNotificationModes( - handle, CompletionModes.FILE_SKIP_SET_EVENT_ON_HANDLE - ) + handle, + CompletionModes.FILE_SKIP_SET_EVENT_ON_HANDLE, + ), ) ################################################################ @@ -638,8 +670,9 @@ def _refresh_afd(self, base_handle: Handle) -> None: try: _check( kernel32.CancelIoEx( - afd_group.handle, waiters.current_op.lpOverlapped - ) + afd_group.handle, + waiters.current_op.lpOverlapped, + ), ) except OSError as exc: if exc.winerror != ErrorCodes.ERROR_NOT_FOUND: @@ -669,7 +702,7 @@ def _refresh_afd(self, base_handle: Handle) -> None: lpOverlapped = ffi.new("LPOVERLAPPED") - poll_info: Any = ffi.new("AFD_POLL_INFO *") + poll_info = cast("_AFDPollInfo", ffi.new("AFD_POLL_INFO *")) poll_info.Timeout = 2**63 - 1 # INT64_MAX poll_info.NumberOfHandles = 1 poll_info.Exclusive = 0 @@ -682,13 +715,13 @@ def _refresh_afd(self, base_handle: Handle) -> None: kernel32.DeviceIoControl( afd_group.handle, IoControlCodes.IOCTL_AFD_POLL, - poll_info, + cast("CType", poll_info), ffi.sizeof("AFD_POLL_INFO"), - poll_info, + cast("CType", poll_info), ffi.sizeof("AFD_POLL_INFO"), ffi.NULL, lpOverlapped, - ) + ), ) except OSError as exc: if exc.winerror != ErrorCodes.ERROR_IO_PENDING: @@ -814,7 +847,9 @@ def register_with_iocp(self, handle: int | CData) -> None: @_public async def wait_overlapped( - self, handle_: int | CData, lpOverlapped: CData | int + self, + handle_: int | CData, + lpOverlapped: CData | int, ) -> object: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 @@ -822,11 +857,11 @@ async def wait_overlapped( `__. """ handle = _handle(handle_) - if isinstance(lpOverlapped, int): + if isinstance(lpOverlapped, int): # TODO: test this line lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped) - if lpOverlapped in self._overlapped_waiters: + if lpOverlapped in self._overlapped_waiters: # TODO: test this line raise _core.BusyResourceError( - "another task is already waiting on that lpOverlapped" + "another task is already waiting on that lpOverlapped", ) task = _core.current_task() self._overlapped_waiters[lpOverlapped] = task @@ -853,8 +888,11 @@ def abort(raise_cancel_: RaiseCancelT) -> Abort: # does, we'll assume the handle wasn't registered. _check( kernel32.PostQueuedCompletionStatus( - self._iocp, 0, CKeys.LATE_CANCEL, lpOverlapped - ) + self._iocp, + 0, + CKeys.LATE_CANCEL, + lpOverlapped, + ), ) # Keep the lpOverlapped referenced so its address # doesn't get reused until our posted completion @@ -864,7 +902,7 @@ def abort(raise_cancel_: RaiseCancelT) -> Abort: self._posted_too_late_to_cancel.add(lpOverlapped) else: # pragma: no cover raise _core.TrioInternalError( - "CancelIoEx failed with unexpected error" + "CancelIoEx failed with unexpected error", ) from exc return _core.Abort.FAILED @@ -889,7 +927,9 @@ def abort(raise_cancel_: RaiseCancelT) -> Abort: return info async def _perform_overlapped( - self, handle: int | CData, submit_fn: Callable[[_Overlapped], None] + self, + handle: int | CData, + submit_fn: Callable[[_Overlapped], None], ) -> _Overlapped: # submit_fn(lpOverlapped) submits some I/O # it may raise an OSError with ERROR_IO_PENDING @@ -899,18 +939,21 @@ async def _perform_overlapped( # operation will not be cancellable, depending on how Windows is # feeling today. So we need to check for cancellation manually. await _core.checkpoint_if_cancelled() - lpOverlapped = cast(_Overlapped, ffi.new("LPOVERLAPPED")) + lpOverlapped = cast("_Overlapped", ffi.new("LPOVERLAPPED")) try: submit_fn(lpOverlapped) except OSError as exc: if exc.winerror != ErrorCodes.ERROR_IO_PENDING: raise - await self.wait_overlapped(handle, cast(CData, lpOverlapped)) + await self.wait_overlapped(handle, cast("CData", lpOverlapped)) return lpOverlapped @_public async def write_overlapped( - self, handle: int | CData, data: Buffer, file_offset: int = 0 + self, + handle: int | CData, + data: Buffer, + file_offset: int = 0, ) -> int: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 @@ -931,7 +974,7 @@ def submit_write(lpOverlapped: _Overlapped) -> None: len(cbuf), ffi.NULL, lpOverlapped, - ) + ), ) lpOverlapped = await self._perform_overlapped(handle, submit_write) @@ -940,7 +983,10 @@ def submit_write(lpOverlapped: _Overlapped) -> None: @_public async def readinto_overlapped( - self, handle: int | CData, buffer: Buffer, file_offset: int = 0 + self, + handle: int | CData, + buffer: Buffer, + file_offset: int = 0, ) -> int: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 @@ -960,7 +1006,7 @@ def submit_read(lpOverlapped: _Overlapped) -> None: len(cbuf), ffi.NULL, lpOverlapped, - ) + ), ) lpOverlapped = await self._perform_overlapped(handle, submit_read) diff --git a/src/trio/_core/_ki.py b/src/trio/_core/_ki.py index a8431f89db..46a7fdf700 100644 --- a/src/trio/_core/_ki.py +++ b/src/trio/_core/_ki.py @@ -1,26 +1,21 @@ from __future__ import annotations -import inspect import signal import sys -from functools import wraps -from typing import TYPE_CHECKING, Final, Protocol, TypeVar +import types +import weakref +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar import attrs from .._util import is_main_thread - -CallableT = TypeVar("CallableT", bound="Callable[..., object]") -RetT = TypeVar("RetT") +from ._run_context import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: import types from collections.abc import Callable - from typing_extensions import ParamSpec, TypeGuard - - ArgsT = ParamSpec("ArgsT") - + from typing_extensions import Self, TypeGuard # In ordinary single-threaded Python code, when you hit control-C, it raises # an exception and automatically does all the regular unwinding stuff. # @@ -83,20 +78,117 @@ # for any Python program that's written to catch and ignore # KeyboardInterrupt.) -# We use this special string as a unique key into the frame locals dictionary. -# The @ ensures it is not a valid identifier and can't clash with any possible -# real local name. See: https://github.com/python-trio/trio/issues/469 -LOCALS_KEY_KI_PROTECTION_ENABLED: Final = "@TRIO_KI_PROTECTION_ENABLED" +_T = TypeVar("_T") + + +class _IdRef(weakref.ref[_T]): + __slots__ = ("_hash",) + _hash: int + + def __new__( + cls, + ob: _T, + callback: Callable[[Self], object] | None = None, + /, + ) -> Self: + self: Self = weakref.ref.__new__(cls, ob, callback) + self._hash = object.__hash__(ob) + return self + + def __eq__(self, other: object) -> bool: + if self is other: + return True + + if not isinstance(other, _IdRef): + return NotImplemented + + my_obj = None + try: + my_obj = self() + return my_obj is not None and my_obj is other() + finally: + del my_obj + + # we're overriding a builtin so we do need this + def __ne__(self, other: object) -> bool: + return not self == other + + def __hash__(self) -> int: + return self._hash + + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +# see also: https://github.com/python/cpython/issues/88306 +class WeakKeyIdentityDictionary(Generic[_KT, _VT]): + def __init__(self) -> None: + self._data: dict[_IdRef[_KT], _VT] = {} + + def remove( + k: _IdRef[_KT], + selfref: weakref.ref[ + WeakKeyIdentityDictionary[_KT, _VT] + ] = weakref.ref( # noqa: B008 # function-call-in-default-argument + self, + ), + ) -> None: + self = selfref() + if self is not None: + try: # noqa: SIM105 # supressible-exception + del self._data[k] + except KeyError: + pass + + self._remove = remove + + def __getitem__(self, k: _KT) -> _VT: + return self._data[_IdRef(k)] + + def __setitem__(self, k: _KT, v: _VT) -> None: + self._data[_IdRef(k, self._remove)] = v + + +_CODE_KI_PROTECTION_STATUS_WMAP: WeakKeyIdentityDictionary[ + types.CodeType, + bool, +] = WeakKeyIdentityDictionary() + + +# This is to support the async_generator package necessary for aclosing on <3.10 +# functions decorated @async_generator are given this magic property that's a +# reference to the object itself +# see python-trio/async_generator/async_generator/_impl.py +def legacy_isasyncgenfunction( + obj: object, +) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]: + return getattr(obj, "_async_gen_function", None) == id(obj) # NB: according to the signal.signal docs, 'frame' can be None on entry to # this function: def ki_protection_enabled(frame: types.FrameType | None) -> bool: + try: + task = GLOBAL_RUN_CONTEXT.task + except AttributeError: + task_ki_protected = False + task_frame = None + else: + task_ki_protected = task._ki_protected + task_frame = task.coro.cr_frame + while frame is not None: - if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals: - return bool(frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED]) + try: + v = _CODE_KI_PROTECTION_STATUS_WMAP[frame.f_code] + except KeyError: + pass + else: + return bool(v) if frame.f_code.co_name == "__del__": return True + if frame is task_frame: + return task_ki_protected frame = frame.f_back return True @@ -117,90 +209,33 @@ def currently_ki_protected() -> bool: return ki_protection_enabled(sys._getframe()) -# This is to support the async_generator package necessary for aclosing on <3.10 -# functions decorated @async_generator are given this magic property that's a -# reference to the object itself -# see python-trio/async_generator/async_generator/_impl.py -def legacy_isasyncgenfunction( - obj: object, -) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]: - return getattr(obj, "_async_gen_function", None) == id(obj) +class _SupportsCode(Protocol): + __code__: types.CodeType -def _ki_protection_decorator( - enabled: bool, -) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]: - # The "ignore[return-value]" below is because the inspect functions cast away the - # original return type of fn, making it just CoroutineType[Any, Any, Any] etc. - # ignore[misc] is because @wraps() is passed a callable with Any in the return type. - def decorator(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: - # In some version of Python, isgeneratorfunction returns true for - # coroutine functions, so we have to check for coroutine functions - # first. - if inspect.iscoroutinefunction(fn): - - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] - # See the comment for regular generators below - coro = fn(*args, **kwargs) - assert coro.cr_frame is not None, "Coroutine frame should exist" - coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return coro # type: ignore[return-value] - - return wrapper - elif inspect.isgeneratorfunction(fn): - - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] - # It's important that we inject this directly into the - # generator's locals, as opposed to setting it here and then - # doing 'yield from'. The reason is, if a generator is - # throw()n into, then it may magically pop to the top of the - # stack. And @contextmanager generators in particular are a - # case where we often want KI protection, and which are often - # thrown into! See: - # https://bugs.python.org/issue29590 - gen = fn(*args, **kwargs) - gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return gen # type: ignore[return-value] - - return wrapper - elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): - - @wraps(fn) # type: ignore[arg-type] - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] - # See the comment for regular generators above - agen = fn(*args, **kwargs) - agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return agen # type: ignore[return-value] - - return wrapper - else: - - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return fn(*args, **kwargs) +_T_supports_code = TypeVar("_T_supports_code", bound=_SupportsCode) - return wrapper - return decorator +def enable_ki_protection(f: _T_supports_code, /) -> _T_supports_code: + """Decorator to enable KI protection.""" + orig = f + if legacy_isasyncgenfunction(f): + f = f.__wrapped__ # type: ignore -# pyright workaround: https://github.com/microsoft/pyright/issues/5866 -class KIProtectionSignature(Protocol): - __name__: str + _CODE_KI_PROTECTION_STATUS_WMAP[f.__code__] = True + return orig - def __call__(self, f: CallableT, /) -> CallableT: - pass +def disable_ki_protection(f: _T_supports_code, /) -> _T_supports_code: + """Decorator to disable KI protection.""" + orig = f -# the following `type: ignore`s are because we use ParamSpec internally, but want to allow overloads -enable_ki_protection: KIProtectionSignature = _ki_protection_decorator(True) # type: ignore[assignment] -enable_ki_protection.__name__ = "enable_ki_protection" + if legacy_isasyncgenfunction(f): + f = f.__wrapped__ # type: ignore -disable_ki_protection: KIProtectionSignature = _ki_protection_decorator(False) # type: ignore[assignment] -disable_ki_protection.__name__ = "disable_ki_protection" + _CODE_KI_PROTECTION_STATUS_WMAP[f.__code__] = False + return orig @attrs.define(slots=False) diff --git a/src/trio/_core/_local.py b/src/trio/_core/_local.py index dd20776c54..fff1234f59 100644 --- a/src/trio/_core/_local.py +++ b/src/trio/_core/_local.py @@ -16,7 +16,7 @@ class _NoValue: ... @final -@attrs.define(eq=False, hash=False) +@attrs.define(eq=False) class RunVarToken(Generic[T], metaclass=NoPublicConstructor): _var: RunVar[T] previous_value: T | type[_NoValue] = _NoValue @@ -28,7 +28,7 @@ def _empty(cls, var: RunVar[T]) -> RunVarToken[T]: @final -@attrs.define(eq=False, hash=False, repr=False) +@attrs.define(eq=False, repr=False) class RunVar(Generic[T]): """The run-local variant of a context variable. @@ -38,13 +38,13 @@ class RunVar(Generic[T]): """ - _name: str - _default: T | type[_NoValue] = _NoValue + _name: str = attrs.field(alias="name") + _default: T | type[_NoValue] = attrs.field(default=_NoValue, alias="default") def get(self, default: T | type[_NoValue] = _NoValue) -> T: """Gets the value of this :class:`RunVar` for the current run call.""" try: - return cast(T, _run.GLOBAL_RUN_CONTEXT.runner._locals[self]) + return cast("T", _run.GLOBAL_RUN_CONTEXT.runner._locals[self]) except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: diff --git a/src/trio/_core/_mock_clock.py b/src/trio/_core/_mock_clock.py index 70c4e58a2d..7e85df2f7b 100644 --- a/src/trio/_core/_mock_clock.py +++ b/src/trio/_core/_mock_clock.py @@ -63,7 +63,7 @@ class MockClock(Clock): """ - def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf) -> None: # when the real clock said 'real_base', the virtual time was # 'virtual_base', and since then it's advanced at 'rate' virtual # seconds per real second. diff --git a/src/trio/_core/_parking_lot.py b/src/trio/_core/_parking_lot.py index 663271f8a3..292012ba1e 100644 --- a/src/trio/_core/_parking_lot.py +++ b/src/trio/_core/_parking_lot.py @@ -70,11 +70,13 @@ # See: https://github.com/python-trio/trio/issues/53 from __future__ import annotations +import inspect import math from collections import OrderedDict from typing import TYPE_CHECKING import attrs +import outcome from .. import _core from .._util import final @@ -85,6 +87,37 @@ from ._run import Task +GLOBAL_PARKING_LOT_BREAKER: dict[Task, list[ParkingLot]] = {} + + +def add_parking_lot_breaker(task: Task, lot: ParkingLot) -> None: + """Register a task as a breaker for a lot. See :meth:`ParkingLot.break_lot`. + + raises: + trio.BrokenResourceError: if the task has already exited. + """ + if inspect.getcoroutinestate(task.coro) == inspect.CORO_CLOSED: + raise _core._exceptions.BrokenResourceError( + "Attempted to add already exited task as lot breaker.", + ) + if task not in GLOBAL_PARKING_LOT_BREAKER: + GLOBAL_PARKING_LOT_BREAKER[task] = [lot] + else: + GLOBAL_PARKING_LOT_BREAKER[task].append(lot) + + +def remove_parking_lot_breaker(task: Task, lot: ParkingLot) -> None: + """Deregister a task as a breaker for a lot. See :meth:`ParkingLot.break_lot`""" + try: + GLOBAL_PARKING_LOT_BREAKER[task].remove(lot) + except (KeyError, ValueError): + raise RuntimeError( + "Attempted to remove task as breaker for a lot it is not registered for", + ) from None + if not GLOBAL_PARKING_LOT_BREAKER[task]: + del GLOBAL_PARKING_LOT_BREAKER[task] + + @attrs.frozen class ParkingLotStatistics: """An object containing debugging information for a ParkingLot. @@ -100,7 +133,7 @@ class ParkingLotStatistics: @final -@attrs.define(eq=False, hash=False) +@attrs.define(eq=False) class ParkingLot: """A fair wait queue with cancellation and requeueing. @@ -117,6 +150,7 @@ class ParkingLot: # {task: None}, we just want a deque where we can quickly delete random # items _parked: OrderedDict[Task, None] = attrs.field(factory=OrderedDict, init=False) + broken_by: list[Task] = attrs.field(factory=list, init=False) def __len__(self) -> int: """Returns the number of parked tasks.""" @@ -135,7 +169,15 @@ async def park(self) -> None: """Park the current task until woken by a call to :meth:`unpark` or :meth:`unpark_all`. + Raises: + BrokenResourceError: if attempting to park in a broken lot, or the lot + breaks before we get to unpark. + """ + if self.broken_by: + raise _core.BrokenResourceError( + f"Attempted to park in parking lot broken by {self.broken_by}", + ) task = _core.current_task() self._parked[task] = None task.custom_sleep_data = self @@ -181,7 +223,10 @@ def unpark_all(self) -> list[Task]: @_core.enable_ki_protection def repark( - self, new_lot: ParkingLot, *, count: int | float = 1 # noqa: PYI041 + self, + new_lot: ParkingLot, + *, + count: int | float = 1, # noqa: PYI041 ) -> None: """Move parked tasks from one :class:`ParkingLot` object to another. @@ -230,6 +275,35 @@ def repark_all(self, new_lot: ParkingLot) -> None: """ return self.repark(new_lot, count=len(self)) + def break_lot(self, task: Task | None = None) -> None: + """Break this lot, with ``task`` noted as the task that broke it. + + This causes all parked tasks to raise an error, and any + future tasks attempting to park to error. Unpark & repark become no-ops as the + parking lot is empty. + + The error raised contains a reference to the task sent as a parameter. The task + is also saved in the parking lot in the ``broken_by`` attribute. + """ + if task is None: + task = _core.current_task() + + # if lot is already broken, just mark this as another breaker and return + if self.broken_by: + self.broken_by.append(task) + return + + self.broken_by.append(task) + + for parked_task in self._parked: + _core.reschedule( + parked_task, + outcome.Error( + _core.BrokenResourceError(f"Parking lot broken by {task}"), + ), + ) + self._parked.clear() + def statistics(self) -> ParkingLotStatistics: """Return an object containing debugging information. diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 5453c3602e..5dbaa18cab 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -7,13 +7,12 @@ import random import select import sys -import threading import warnings from collections import deque from contextlib import AbstractAsyncContextManager, contextmanager, suppress from contextvars import copy_context from heapq import heapify, heappop, heappush -from math import inf +from math import inf, isnan from time import perf_counter from typing import ( TYPE_CHECKING, @@ -39,7 +38,9 @@ from ._entry_queue import EntryQueue, TrioToken from ._exceptions import Cancelled, RunFinishedError, TrioInternalError from ._instrumentation import Instruments -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED, KIManager, enable_ki_protection +from ._ki import KIManager, enable_ki_protection +from ._parking_lot import GLOBAL_PARKING_LOT_BREAKER +from ._run_context import GLOBAL_RUN_CONTEXT as GLOBAL_RUN_CONTEXT from ._thread_cache import start_thread_soon from ._traps import ( Abort, @@ -81,14 +82,13 @@ StatusT = TypeVar("StatusT") StatusT_contra = TypeVar("StatusT_contra", contravariant=True) -FnT = TypeVar("FnT", bound="Callable[..., Any]") RetT = TypeVar("RetT") DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000 # Passed as a sentinel -_NO_SEND: Final[Outcome[Any]] = cast("Outcome[Any]", object()) +_NO_SEND: Final[Outcome[object]] = cast("Outcome[object]", object()) # Used to track if an exceptiongroup can be collapsed NONSTRICT_EXCEPTIONGROUP_NOTE = 'This is a "loose" ExceptionGroup, and may be collapsed by Trio if it only contains one exception - typically after `Cancelled` has been stripped from it. Note this has consequences for exception handling, and strict_exception_groups=True is recommended.' @@ -101,7 +101,7 @@ class _NoStatus(metaclass=NoPublicConstructor): # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn: FnT) -> FnT: +def _public(fn: RetT) -> RetT: return fn @@ -115,7 +115,8 @@ def _public(fn: FnT) -> FnT: _r = random.Random() -def _hypothesis_plugin_setup() -> None: +# no cover because we don't check the hypothesis plugin works with hypothesis +def _hypothesis_plugin_setup() -> None: # pragma: no cover from hypothesis import register_random global _ALLOW_DETERMINISTIC_SCHEDULING @@ -142,7 +143,7 @@ def function_with_unique_name_xyzzy() -> NoReturn: raise else: # pragma: no cover raise TrioInternalError( - "A ZeroDivisionError should have been raised, but it wasn't." + "A ZeroDivisionError should have been raised, but it wasn't.", ) ctx = copy_context() @@ -160,7 +161,7 @@ def function_with_unique_name_xyzzy() -> NoReturn: else: # pragma: no cover raise TrioInternalError( f"The purpose of {function_with_unique_name_xyzzy.__name__} is " - "to raise a ZeroDivisionError, but it didn't." + "to raise a ZeroDivisionError, but it didn't.", ) @@ -219,7 +220,8 @@ def collapse_exception_group( and NONSTRICT_EXCEPTIONGROUP_NOTE in getattr(excgroup, "__notes__", ()) ): exceptions[0].__traceback__ = concat_tb( - excgroup.__traceback__, exceptions[0].__traceback__ + excgroup.__traceback__, + exceptions[0].__traceback__, ) return exceptions[0] elif modified: @@ -542,17 +544,40 @@ class CancelScope: cancelled_caught: bool = attrs.field(default=False, init=False) # Constructor arguments: + _relative_deadline: float = attrs.field( + default=inf, + kw_only=True, + alias="relative_deadline", + ) _deadline: float = attrs.field(default=inf, kw_only=True, alias="deadline") _shield: bool = attrs.field(default=False, kw_only=True, alias="shield") + def __attrs_post_init__(self) -> None: + if isnan(self._deadline): + raise ValueError("deadline must not be NaN") + if isnan(self._relative_deadline): + raise ValueError("relative deadline must not be NaN") + if self._relative_deadline < 0: + raise ValueError("timeout must be non-negative") + if self._relative_deadline != inf and self._deadline != inf: + raise ValueError( + "Cannot specify both a deadline and a relative deadline", + ) + @enable_ki_protection def __enter__(self) -> Self: task = _core.current_task() if self._has_been_entered: raise RuntimeError( - "Each CancelScope may only be used for a single 'with' block" + "Each CancelScope may only be used for a single 'with' block", ) self._has_been_entered = True + + if self._relative_deadline != inf: + assert self._deadline == inf + self._deadline = current_time() + self._relative_deadline + self._relative_deadline = inf + if current_time() >= self._deadline: self.cancel() with self._might_change_registered_deadline(): @@ -564,7 +589,7 @@ def _close(self, exc: BaseException | None) -> BaseException | None: if self._cancel_status is None: new_exc = RuntimeError( f"Cancel scope stack corrupted: attempted to exit {self!r} " - "which had already been exited" + "which had already been exited", ) new_exc.__context__ = exc return new_exc @@ -586,7 +611,7 @@ def _close(self, exc: BaseException | None) -> BaseException | None: # without changing any state. new_exc = RuntimeError( f"Cancel scope stack corrupted: attempted to exit {self!r} " - f"from unrelated {scope_task!r}\n{MISNESTING_ADVICE}" + f"from unrelated {scope_task!r}\n{MISNESTING_ADVICE}", ) new_exc.__context__ = exc return new_exc @@ -598,7 +623,7 @@ def _close(self, exc: BaseException | None) -> BaseException | None: # pass silently. new_exc = RuntimeError( f"Cancel scope stack corrupted: attempted to exit {self!r} " - f"in {scope_task!r} that's still within its child {scope_task._cancel_status._scope!r}\n{MISNESTING_ADVICE}" + f"in {scope_task!r} that's still within its child {scope_task._cancel_status._scope!r}\n{MISNESTING_ADVICE}", ) new_exc.__context__ = exc exc = new_exc @@ -733,13 +758,70 @@ def deadline(self) -> float: this can be overridden by the ``deadline=`` argument to the :class:`~trio.CancelScope` constructor. """ + if self._relative_deadline != inf: + assert self._deadline == inf + warnings.warn( + DeprecationWarning( + "unentered relative cancel scope does not have an absolute deadline. Use `.relative_deadline`", + ), + stacklevel=2, + ) + return current_time() + self._relative_deadline return self._deadline @deadline.setter def deadline(self, new_deadline: float) -> None: + if isnan(new_deadline): + raise ValueError("deadline must not be NaN") + if self._relative_deadline != inf: + assert self._deadline == inf + warnings.warn( + DeprecationWarning( + "unentered relative cancel scope does not have an absolute deadline. Transforming into an absolute cancel scope. First set `.relative_deadline = math.inf` if you do want an absolute cancel scope.", + ), + stacklevel=2, + ) + self._relative_deadline = inf with self._might_change_registered_deadline(): self._deadline = float(new_deadline) + @property + def relative_deadline(self) -> float: + if self._has_been_entered: + return self._deadline - current_time() + elif self._deadline != inf: + assert self._relative_deadline == inf + raise RuntimeError( + "unentered non-relative cancel scope does not have a relative deadline", + ) + return self._relative_deadline + + @relative_deadline.setter + def relative_deadline(self, new_relative_deadline: float) -> None: + if isnan(new_relative_deadline): + raise ValueError("relative deadline must not be NaN") + if new_relative_deadline < 0: + raise ValueError("relative deadline must be non-negative") + if self._has_been_entered: + with self._might_change_registered_deadline(): + self._deadline = current_time() + float(new_relative_deadline) + elif self._deadline != inf: + assert self._relative_deadline == inf + raise RuntimeError( + "unentered non-relative cancel scope does not have a relative deadline", + ) + else: + self._relative_deadline = new_relative_deadline + + @property + def is_relative(self) -> bool | None: + """Returns None after entering. Returns False if both deadline and + relative_deadline are inf.""" + assert not (self._deadline != inf and self._relative_deadline != inf) + if self._has_been_entered: + return None + return self._relative_deadline != inf + @property def shield(self) -> bool: """Read-write, :class:`bool`, default :data:`False`. So long as @@ -845,7 +927,7 @@ def started(self, value: StatusT_contra | None = None) -> None: # This code needs to be read alongside the code from Nursery.start to make # sense. -@attrs.define(eq=False, hash=False, repr=False, slots=False) +@attrs.define(eq=False, repr=False, slots=False) class _TaskStatus(TaskStatus[StatusT]): _old_nursery: Nursery _new_nursery: Nursery @@ -864,7 +946,7 @@ def started(self: _TaskStatus[StatusT], value: StatusT) -> None: ... def started(self, value: StatusT | None = None) -> None: if self._value is not _NoStatus: raise RuntimeError("called 'started' twice on the same task status") - self._value = cast(StatusT, value) # If None, StatusT == None + self._value = cast("StatusT", value) # If None, StatusT == None # If the old nursery is cancelled, then quietly quit now; the child # will eventually exit on its own, and we don't want to risk moving @@ -932,7 +1014,9 @@ async def __aenter__(self) -> Nursery: self._scope = CancelScope() self._scope.__enter__() self._nursery = Nursery._create( - current_task(), self._scope, self.strict_exception_groups + current_task(), + self._scope, + self.strict_exception_groups, ) return self._nursery @@ -970,7 +1054,7 @@ async def __aexit__( def __enter__(self) -> NoReturn: raise RuntimeError( - "use 'async with open_nursery(...)', not 'with open_nursery(...)'" + "use 'async with open_nursery(...)', not 'with open_nursery(...)'", ) def __exit__( @@ -1049,7 +1133,7 @@ def __init__( parent_task: Task, cancel_scope: CancelScope, strict_exception_groups: bool, - ): + ) -> None: self._parent_task = parent_task self._strict_exception_groups = strict_exception_groups parent_task._child_nurseries.append(self) @@ -1092,14 +1176,19 @@ def _check_nursery_closed(self) -> None: self._parent_waiting_in_aexit = False GLOBAL_RUN_CONTEXT.runner.reschedule(self._parent_task) - def _child_finished(self, task: Task, outcome: Outcome[Any]) -> None: + def _child_finished( + self, + task: Task, + outcome: Outcome[object], + ) -> None: self._children.remove(task) if isinstance(outcome, Error): self._add_exc(outcome.error) self._check_nursery_closed() async def _nested_child_finished( - self, nested_child_exc: BaseException | None + self, + nested_child_exc: BaseException | None, ) -> BaseException | None: # Returns ExceptionGroup instance (or any exception if the nursery is in loose mode # and there is just one contained exception) if there are pending exceptions @@ -1138,7 +1227,8 @@ def aborted(raise_cancel: _core.RaiseCancelT) -> Abort: if not self._strict_exception_groups and len(self._pending_excs) == 1: return self._pending_excs[0] exception = BaseExceptionGroup( - "Exceptions from Trio nursery", self._pending_excs + "Exceptions from Trio nursery", + self._pending_excs, ) if not self._strict_exception_groups: exception.add_note(NONSTRICT_EXCEPTIONGROUP_NOTE) @@ -1196,12 +1286,14 @@ def start_soon( """ GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name) - async def start( + # Typing changes blocked by https://github.com/python/mypy/pull/17512 + # Explicit "Any" is not allowed + async def start( # type: ignore[misc] self, async_fn: Callable[..., Awaitable[object]], *args: object, name: object = None, - ) -> Any: + ) -> Any | None: r"""Creates and initializes a child task. Like :meth:`start_soon`, but blocks until the new task has @@ -1252,10 +1344,16 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): # set strict_exception_groups = True to make sure we always unwrap # *this* nursery's exceptiongroup async with open_nursery(strict_exception_groups=True) as old_nursery: - task_status: _TaskStatus[Any] = _TaskStatus(old_nursery, self) + task_status: _TaskStatus[object | None] = _TaskStatus( + old_nursery, + self, + ) thunk = functools.partial(async_fn, task_status=task_status) task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( - thunk, args, old_nursery, name + thunk, + args, + old_nursery, + name, ) task._eventual_parent_nursery = self # Wait for either TaskStatus.started or an exception to @@ -1266,7 +1364,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): raise TrioInternalError( "Internal nursery should not have multiple tasks. This can be " 'caused by the user managing to access the "old" nursery in ' - "`task_status` and spawning tasks in it." + "`task_status` and spawning tasks in it.", ) from exc # If we get here, then the child either got reparented or exited @@ -1289,14 +1387,16 @@ def __del__(self) -> None: @final -@attrs.define(eq=False, hash=False, repr=False) -class Task(metaclass=NoPublicConstructor): +@attrs.define(eq=False, repr=False) +class Task(metaclass=NoPublicConstructor): # type: ignore[misc] _parent_nursery: Nursery | None - coro: Coroutine[Any, Outcome[object], Any] + # Explicit "Any" is not allowed + coro: Coroutine[Any, Outcome[object], Any] # type: ignore[misc] _runner: Runner name: str context: contextvars.Context _counter: int = attrs.field(init=False, factory=itertools.count().__next__) + _ki_protected: bool # Invariant: # - for unscheduled tasks, _next_send_fn and _next_send are both None @@ -1309,10 +1409,11 @@ class Task(metaclass=NoPublicConstructor): # tracebacks with extraneous frames. # - for scheduled tasks, custom_sleep_data is None # Tasks start out unscheduled. - _next_send_fn: Callable[[Any], object] | None = None - _next_send: Outcome[Any] | None | BaseException = None + # Explicit "Any" is not allowed + _next_send_fn: Callable[[Any], object] | None = None # type: ignore[misc] + _next_send: Outcome[Any] | BaseException | None = None # type: ignore[misc] _abort_func: Callable[[_core.RaiseCancelT], Abort] | None = None - custom_sleep_data: Any = None + custom_sleep_data: Any = None # type: ignore[misc] # For introspection and nursery.start() _child_nurseries: list[Nursery] = attrs.Factory(list) @@ -1380,7 +1481,7 @@ def print_stack_for_task(task): """ # Ignore static typing as we're doing lots of dynamic introspection - coro: Any = self.coro + coro: Any = self.coro # type: ignore[misc] while coro is not None: if hasattr(coro, "cr_frame"): # A real coroutine @@ -1472,14 +1573,6 @@ def raise_cancel() -> NoReturn: ################################################################ -class RunContext(threading.local): - runner: Runner - task: Task - - -GLOBAL_RUN_CONTEXT: Final = RunContext() - - @attrs.frozen class RunStatistics: """An object containing run-loop-level debugging information. @@ -1532,14 +1625,17 @@ class RunStatistics: # worker thread. -@attrs.define(eq=False, hash=False) -class GuestState: +@attrs.define(eq=False) +# Explicit "Any" is not allowed +class GuestState: # type: ignore[misc] runner: Runner run_sync_soon_threadsafe: Callable[[Callable[[], object]], object] run_sync_soon_not_threadsafe: Callable[[Callable[[], object]], object] - done_callback: Callable[[Outcome[Any]], object] + # Explicit "Any" is not allowed + done_callback: Callable[[Outcome[Any]], object] # type: ignore[misc] unrolled_run_gen: Generator[float, EventResult, None] - unrolled_run_next_send: Outcome[Any] = attrs.Factory(lambda: Value(None)) + # Explicit "Any" is not allowed + unrolled_run_next_send: Outcome[Any] = attrs.Factory(lambda: Value(None)) # type: ignore[misc] def guest_tick(self) -> None: prev_library, sniffio_library.name = sniffio_library.name, "trio" @@ -1557,7 +1653,8 @@ def guest_tick(self) -> None: # Optimization: try to skip going into the thread if we can avoid it events_outcome: Value[EventResult] | Error = capture( - self.runner.io_manager.get_events, 0 + self.runner.io_manager.get_events, + 0, ) if timeout <= 0 or isinstance(events_outcome, Error) or events_outcome.value: # No need to go into the thread @@ -1582,8 +1679,9 @@ def in_main_thread() -> None: start_thread_soon(get_events, deliver) -@attrs.define(eq=False, hash=False) -class Runner: +@attrs.define(eq=False) +# Explicit "Any" is not allowed +class Runner: # type: ignore[misc] clock: Clock instruments: Instruments io_manager: TheIOManager @@ -1591,7 +1689,8 @@ class Runner: strict_exception_groups: bool # Run-local values, see _local.py - _locals: dict[_core.RunVar[Any], Any] = attrs.Factory(dict) + # Explicit "Any" is not allowed + _locals: dict[_core.RunVar[Any], object] = attrs.Factory(dict) # type: ignore[misc] runq: deque[Task] = attrs.Factory(deque) tasks: set[Task] = attrs.Factory(set) @@ -1602,7 +1701,7 @@ class Runner: system_nursery: Nursery | None = None system_context: contextvars.Context = attrs.field(kw_only=True) main_task: Task | None = None - main_task_outcome: Outcome[Any] | None = None + main_task_outcome: Outcome[object] | None = None entry_queue: EntryQueue = attrs.Factory(EntryQueue) trio_token: TrioToken | None = None @@ -1694,10 +1793,8 @@ def current_root_task(self) -> Task | None: # Core task handling primitives ################ - @_public # Type-ignore due to use of Any here. - def reschedule( # type: ignore[misc] - self, task: Task, next_send: Outcome[Any] = _NO_SEND - ) -> None: + @_public + def reschedule(self, task: Task, next_send: Outcome[object] = _NO_SEND) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -1783,13 +1880,17 @@ async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: coro = python_wrapper(coro) assert coro.cr_frame is not None, "Coroutine frame should exist" - coro.cr_frame.f_locals.setdefault(LOCALS_KEY_KI_PROTECTION_ENABLED, system_task) ###### # Set up the Task object ###### task = Task._create( - coro=coro, parent_nursery=nursery, runner=self, name=name, context=context + coro=coro, + parent_nursery=nursery, + runner=self, + name=name, + context=context, + ki_protected=system_task, ) self.tasks.add(task) @@ -1804,7 +1905,13 @@ async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: self.reschedule(task, None) # type: ignore[arg-type] return task - def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: + def task_exited(self, task: Task, outcome: Outcome[object]) -> None: + # break parking lots associated with the exiting task + if task in GLOBAL_PARKING_LOT_BREAKER: + for lot in GLOBAL_PARKING_LOT_BREAKER[task]: + lot.break_lot(task) + del GLOBAL_PARKING_LOT_BREAKER[task] + if ( task._cancel_status is not None and task._cancel_status.abandoned_by_misnesting @@ -1819,7 +1926,7 @@ def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: # traceback frame included raise RuntimeError( "Cancel scope stack corrupted: cancel scope surrounding " - f"{task!r} was closed before the task exited\n{MISNESTING_ADVICE}" + f"{task!r} was closed before the task exited\n{MISNESTING_ADVICE}", ) except RuntimeError as new_exc: if isinstance(outcome, Error): @@ -1932,7 +2039,10 @@ async def init( async with open_nursery() as main_task_nursery: try: self.main_task = self.spawn_impl( - async_fn, args, main_task_nursery, None + async_fn, + args, + main_task_nursery, + None, ) except BaseException as exc: self.main_task_outcome = Error(exc) @@ -2007,7 +2117,8 @@ def _deliver_ki_cb(self) -> None: # sortedcontainers doesn't have types, and is reportedly very hard to type: # https://github.com/grantjenks/python-sortedcontainers/issues/68 - waiting_for_idle: Any = attrs.Factory(SortedDict) + # Explicit "Any" is not allowed + waiting_for_idle: Any = attrs.Factory(SortedDict) # type: ignore[misc] @_public async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: @@ -2301,14 +2412,15 @@ def run( # Inlined copy of runner.main_task_outcome.unwrap() to avoid # cluttering every single Trio traceback with an extra frame. if isinstance(runner.main_task_outcome, Value): - return cast(RetT, runner.main_task_outcome.value) + return cast("RetT", runner.main_task_outcome.value) elif isinstance(runner.main_task_outcome, Error): raise runner.main_task_outcome.error else: # pragma: no cover raise AssertionError(runner.main_task_outcome) -def start_guest_run( +# Explicit .../"Any" not allowed +def start_guest_run( # type: ignore[misc] async_fn: Callable[..., Awaitable[RetT]], *args: object, run_sync_soon_threadsafe: Callable[[Callable[[], object]], object], @@ -2424,7 +2536,8 @@ def my_done_callback(run_outcome): # this time, so it shouldn't be possible to get an exception here, # except for a TrioInternalError. next_send = cast( - EventResult, None + "EventResult", + None, ) # First iteration must be `None`, every iteration after that is EventResult for _tick in range(5): # expected need is 2 iterations + leave some wiggle room if runner.system_nursery is not None: @@ -2434,13 +2547,13 @@ def my_done_callback(run_outcome): timeout = guest_state.unrolled_run_gen.send(next_send) except StopIteration: # pragma: no cover raise TrioInternalError( - "Guest runner exited before system nursery was initialized" + "Guest runner exited before system nursery was initialized", ) from None if timeout != 0: # pragma: no cover guest_state.unrolled_run_gen.throw( TrioInternalError( - "Guest runner blocked before system nursery was initialized" - ) + "Guest runner blocked before system nursery was initialized", + ), ) # next_send should be the return value of # IOManager.get_events() if no I/O was waiting, which is @@ -2454,8 +2567,8 @@ def my_done_callback(run_outcome): guest_state.unrolled_run_gen.throw( TrioInternalError( "Guest runner yielded too many times before " - "system nursery was initialized" - ) + "system nursery was initialized", + ), ) guest_state.unrolled_run_next_send = Value(next_send) @@ -2471,13 +2584,13 @@ def my_done_callback(run_outcome): # mode", where our core event loop gets unrolled into a series of callbacks on # the host loop. If you're doing a regular trio.run then this gets run # straight through. +@enable_ki_protection def unrolled_run( runner: Runner, async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], args: tuple[Unpack[PosArgT]], host_uses_signal_set_wakeup_fd: bool = False, ) -> Generator[float, EventResult, None]: - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True __tracebackhide__ = True try: @@ -2488,7 +2601,11 @@ def unrolled_run( runner.instruments.call("before_run") runner.clock.start_clock() runner.init_task = runner.spawn_impl( - runner.init, (async_fn, args), None, "", system_task=True + runner.init, + (async_fn, args), + None, + "", + system_task=True, ) # You know how people talk about "event loops"? This 'while' loop right @@ -2607,7 +2724,7 @@ def unrolled_run( next_send_fn = task._next_send_fn next_send = task._next_send task._next_send_fn = task._next_send = None - final_outcome: Outcome[Any] | None = None + final_outcome: Outcome[object] | None = None try: # We used to unwrap the Outcome object here and send/throw # its contents in directly, but it turns out that .throw() @@ -2666,7 +2783,7 @@ def unrolled_run( f"trio.run received unrecognized yield message {msg!r}. " "Are you trying to use a library written for some " "other framework like asyncio? That won't work " - "without some kind of compatibility shim." + "without some kind of compatibility shim.", ) # The foreign library probably doesn't adhere to our # protocol of unwrapping whatever outcome gets sent in. @@ -2690,7 +2807,7 @@ def unrolled_run( warnings.warn( RuntimeWarning( "Trio guest run got abandoned without properly finishing... " - "weird stuff might happen" + "weird stuff might happen", ), stacklevel=1, ) @@ -2716,15 +2833,15 @@ def unrolled_run( ################################################################ -class _TaskStatusIgnored(TaskStatus[Any]): +class _TaskStatusIgnored(TaskStatus[object]): def __repr__(self) -> str: return "TASK_STATUS_IGNORED" - def started(self, value: Any = None) -> None: + def started(self, value: object = None) -> None: pass -TASK_STATUS_IGNORED: Final[TaskStatus[Any]] = _TaskStatusIgnored() +TASK_STATUS_IGNORED: Final[TaskStatus[object]] = _TaskStatusIgnored() def current_task() -> Task: @@ -2841,6 +2958,13 @@ async def checkpoint_if_cancelled() -> None: _KqueueStatistics as IOStatistics, ) else: # pragma: no cover + _patchers = sorted({"eventlet", "gevent"}.intersection(sys.modules)) + if _patchers: + raise NotImplementedError( + "unsupported platform or primitives trio depends on are monkey-patched out by " + + ", ".join(_patchers), + ) + raise NotImplementedError("unsupported platform") from ._generated_instrumentation import * diff --git a/src/trio/_core/_run_context.py b/src/trio/_core/_run_context.py new file mode 100644 index 0000000000..085bff9a34 --- /dev/null +++ b/src/trio/_core/_run_context.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Final + +if TYPE_CHECKING: + from ._run import Runner, Task + + +class RunContext(threading.local): + runner: Runner + task: Task + + +GLOBAL_RUN_CONTEXT: Final = RunContext() diff --git a/src/trio/_core/_tests/test_asyncgen.py b/src/trio/_core/_tests/test_asyncgen.py index aa64a49d70..5db88d7926 100644 --- a/src/trio/_core/_tests/test_asyncgen.py +++ b/src/trio/_core/_tests/test_asyncgen.py @@ -43,7 +43,8 @@ async def example(cause: str) -> AsyncGenerator[int, None]: async def async_main() -> None: # GC'ed before exhausted with pytest.warns( - ResourceWarning, match="Async generator.*collected before.*exhausted" + ResourceWarning, + match="Async generator.*collected before.*exhausted", ): assert await example("abandoned").asend(None) == 42 gc_collect_harder() @@ -109,7 +110,7 @@ async def agen() -> AsyncGenerator[int, None]: await _core.wait_all_tasks_blocked() assert record == ["crashing"] # Following type ignore is because typing for LogCaptureFixture is wrong - exc_type, exc_value, exc_traceback = caplog.records[0].exc_info # type: ignore[misc] + exc_type, exc_value, _exc_traceback = caplog.records[0].exc_info # type: ignore[misc] assert exc_type is ValueError assert str(exc_value) == "oops" assert "during finalization of async generator" in caplog.records[0].message @@ -153,7 +154,8 @@ async def innermost() -> AsyncGenerator[int, None]: record.append("innermost") async def agen( - label: int, inner: AsyncGenerator[int, None] + label: int, + inner: AsyncGenerator[int, None], ) -> AsyncGenerator[int, None]: try: yield await inner.asend(None) @@ -197,7 +199,8 @@ def collect_at_opportune_moment(token: _core._entry_queue.TrioToken) -> None: runner = _core._run.GLOBAL_RUN_CONTEXT.runner assert runner.system_nursery is not None if runner.system_nursery._closed and isinstance( - runner.asyncgens.alive, weakref.WeakSet + runner.asyncgens.alive, + weakref.WeakSet, ): saved.clear() record.append("final collection") @@ -224,8 +227,8 @@ async def async_main() -> None: # failure as small as we want. for _attempt in range(50): needs_retry = False - del record[:] - del saved[:] + record.clear() + saved.clear() _core.run(async_main) if needs_retry: # pragma: no cover assert record == ["cleaned up"] @@ -235,7 +238,7 @@ async def async_main() -> None: else: # pragma: no cover pytest.fail( "Didn't manage to hit the trailing_finalizer_asyncgens case " - f"despite trying {_attempt} times" + f"despite trying {_attempt} times", ) diff --git a/src/trio/_core/_tests/test_exceptiongroup_gc.py b/src/trio/_core/_tests/test_exceptiongroup_gc.py index 8957a581a5..885ef68624 100644 --- a/src/trio/_core/_tests/test_exceptiongroup_gc.py +++ b/src/trio/_core/_tests/test_exceptiongroup_gc.py @@ -76,7 +76,8 @@ def test_concat_tb() -> None: # Unclear if this can still fail, removing the `del` from _concat_tb.copy_tb does not seem # to trigger it (on a platform where the `del` is executed) @pytest.mark.skipif( - sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" + sys.implementation.name != "cpython", + reason="Only makes sense with refcounting GC", ) def test_ExceptionGroup_catch_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/pull/2063 diff --git a/src/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py index 8972ec735a..526932b949 100644 --- a/src/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -2,7 +2,6 @@ import asyncio import contextlib -import contextvars import queue import signal import socket @@ -11,26 +10,25 @@ import time import traceback import warnings +import weakref +from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence from functools import partial from math import inf from typing import ( TYPE_CHECKING, - Any, - AsyncGenerator, - Awaitable, - Callable, NoReturn, TypeVar, + cast, ) import pytest +import sniffio from outcome import Outcome import trio import trio.testing -from trio.abc import Instrument +from trio.abc import Clock, Instrument -from ..._util import signal_raise from .tutil import gc_collect_harder, restore_unraisablehook if TYPE_CHECKING: @@ -39,7 +37,7 @@ from trio._channel import MemorySendChannel T = TypeVar("T") -InHost: TypeAlias = Callable[[object], None] +InHost: TypeAlias = Callable[[Callable[[], object]], None] # The simplest possible "host" loop. @@ -49,12 +47,16 @@ # - final result is returned # - any unhandled exceptions cause an immediate crash def trivial_guest_run( - trio_fn: Callable[..., Awaitable[T]], + trio_fn: Callable[[InHost], Awaitable[T]], *, in_host_after_start: Callable[[], None] | None = None, - **start_guest_run_kwargs: Any, + host_uses_signal_set_wakeup_fd: bool = False, + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = True, ) -> T: - todo: queue.Queue[tuple[str, Outcome[T] | Callable[..., object]]] = queue.Queue() + todo: queue.Queue[tuple[str, Outcome[T] | Callable[[], object]]] = queue.Queue() host_thread = threading.current_thread() @@ -62,7 +64,8 @@ def run_sync_soon_threadsafe(fn: Callable[[], object]) -> None: nonlocal todo if host_thread is threading.current_thread(): # pragma: no cover crash = partial( - pytest.fail, "run_sync_soon_threadsafe called from host thread" + pytest.fail, + "run_sync_soon_threadsafe called from host thread", ) todo.put(("run", crash)) todo.put(("run", fn)) @@ -71,7 +74,8 @@ def run_sync_soon_not_threadsafe(fn: Callable[[], object]) -> None: nonlocal todo if host_thread is not threading.current_thread(): # pragma: no cover crash = partial( - pytest.fail, "run_sync_soon_not_threadsafe called from worker thread" + pytest.fail, + "run_sync_soon_not_threadsafe called from worker thread", ) todo.put(("run", crash)) todo.put(("run", fn)) @@ -86,7 +90,11 @@ def done_callback(outcome: Outcome[T]) -> None: run_sync_soon_threadsafe=run_sync_soon_threadsafe, run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe, done_callback=done_callback, - **start_guest_run_kwargs, + host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, + clock=clock, + instruments=instruments, + restrict_keyboard_interrupt_to_checkpoints=restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups=strict_exception_groups, ) if in_host_after_start is not None: in_host_after_start() @@ -170,10 +178,16 @@ async def early_task() -> None: assert res == "ok" assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"} - class BadClock: + class BadClock(Clock): def start_clock(self) -> NoReturn: raise ValueError("whoops") + def current_time(self) -> float: + raise NotImplementedError() + + def deadline_to_sleep_time(self, deadline: float) -> float: + raise NotImplementedError() + def after_start_never_runs() -> None: # pragma: no cover pytest.fail("shouldn't get here") @@ -181,7 +195,9 @@ def after_start_never_runs() -> None: # pragma: no cover # are raised out of start_guest_run, not out of the done_callback with pytest.raises(trio.TrioInternalError): trivial_guest_run( - trio_main, clock=BadClock(), in_host_after_start=after_start_never_runs + trio_main, + clock=BadClock(), + in_host_after_start=after_start_never_runs, ) @@ -219,7 +235,8 @@ async def trio_main(in_host: InHost) -> str: def test_guest_mode_sniffio_integration() -> None: - from sniffio import current_async_library, thread_local as sniffio_library + current_async_library = sniffio.current_async_library + sniffio_library = sniffio.thread_local async def trio_main(in_host: InHost) -> str: async def synchronize() -> None: @@ -331,7 +348,7 @@ async def sit_in_wait_all_tasks_blocked(watb_cscope: trio.CancelScope) -> None: await trio.testing.wait_all_tasks_blocked(cushion=9999) raise AssertionError( # pragma: no cover "wait_all_tasks_blocked should *not* return normally, " - "only by cancellation." + "only by cancellation.", ) assert watb_cscope.cancelled_caught @@ -429,33 +446,46 @@ async def abandoned_main(in_host: InHost) -> None: def aiotrio_run( - trio_fn: Callable[..., Awaitable[T]], + trio_fn: Callable[[], Awaitable[T]], *, pass_not_threadsafe: bool = True, - **start_guest_run_kwargs: Any, + run_sync_soon_not_threadsafe: InHost | None = None, + host_uses_signal_set_wakeup_fd: bool = False, + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = True, ) -> T: loop = asyncio.new_event_loop() async def aio_main() -> T: - trio_done_fut = loop.create_future() + nonlocal run_sync_soon_not_threadsafe + trio_done_fut: asyncio.Future[Outcome[T]] = loop.create_future() - def trio_done_callback(main_outcome: Outcome[object]) -> None: + def trio_done_callback(main_outcome: Outcome[T]) -> None: print(f"trio_fn finished: {main_outcome!r}") trio_done_fut.set_result(main_outcome) if pass_not_threadsafe: - start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon + run_sync_soon_not_threadsafe = cast("InHost", loop.call_soon) trio.lowlevel.start_guest_run( trio_fn, run_sync_soon_threadsafe=loop.call_soon_threadsafe, done_callback=trio_done_callback, - **start_guest_run_kwargs, + run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe, + host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, + clock=clock, + instruments=instruments, + restrict_keyboard_interrupt_to_checkpoints=restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups=strict_exception_groups, ) - return (await trio_done_fut).unwrap() # type: ignore[no-any-return] + return (await trio_done_fut).unwrap() try: + # can't use asyncio.run because that fails on Windows (3.8, x64, with + # Komodia LSP) and segfaults on Windows (3.9, x64, with Komodia LSP) return loop.run_until_complete(aio_main()) finally: loop.close() @@ -486,7 +516,8 @@ async def trio_main() -> str: raise AssertionError("should never be reached") # pragma: no cover async def aio_pingpong( - from_trio: asyncio.Queue[int], to_trio: MemorySendChannel[int] + from_trio: asyncio.Queue[int], + to_trio: MemorySendChannel[int], ) -> None: print("aio_pingpong!") @@ -525,7 +556,8 @@ async def aio_pingpong( def test_guest_mode_internal_errors( - monkeypatch: pytest.MonkeyPatch, recwarn: pytest.WarningsRecorder + monkeypatch: pytest.MonkeyPatch, + recwarn: pytest.WarningsRecorder, ) -> None: with monkeypatch.context() as m: @@ -551,11 +583,14 @@ async def crash_in_worker_thread_io(in_host: InHost) -> None: t = threading.current_thread() old_get_events = trio._core._run.TheIOManager.get_events - def bad_get_events(*args: Any) -> object: + def bad_get_events( + self: trio._core._run.TheIOManager, + timeout: float, + ) -> trio._core._run.EventResult: if threading.current_thread() is not t: raise ValueError("oh no!") else: - return old_get_events(*args) + return old_get_events(self, timeout) m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events) @@ -573,10 +608,10 @@ def test_guest_mode_ki() -> None: # Check SIGINT in Trio func and in host func async def trio_main(in_host: InHost) -> None: with pytest.raises(KeyboardInterrupt): - signal_raise(signal.SIGINT) + signal.raise_signal(signal.SIGINT) # Host SIGINT should get injected into Trio - in_host(partial(signal_raise, signal.SIGINT)) + in_host(partial(signal.raise_signal, signal.SIGINT)) await trio.sleep(10) with pytest.raises(KeyboardInterrupt) as excinfo: @@ -589,7 +624,7 @@ async def trio_main(in_host: InHost) -> None: final_exc = KeyError("whoa") async def trio_main_raising(in_host: InHost) -> NoReturn: - in_host(partial(signal_raise, signal.SIGINT)) + in_host(partial(signal.raise_signal, signal.SIGINT)) raise final_exc with pytest.raises(KeyboardInterrupt) as excinfo: @@ -624,8 +659,6 @@ async def trio_main(in_host: InHost) -> None: @restore_unraisablehook() def test_guest_mode_asyncgens() -> None: - import sniffio - record = set() async def agen(label: str) -> AsyncGenerator[int, None]: @@ -652,9 +685,49 @@ async def trio_main() -> None: gc_collect_harder() - # Ensure we don't pollute the thread-level context if run under - # an asyncio without contextvars support (3.6) - context = contextvars.copy_context() - context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) + aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True) assert record == {("asyncio", "asyncio"), ("trio", "trio")} + + +@restore_unraisablehook() +def test_guest_mode_asyncgens_garbage_collection() -> None: + record: set[tuple[str, str, bool]] = set() + + async def agen(label: str) -> AsyncGenerator[int, None]: + class A: + pass + + a = A() + a_wr = weakref.ref(a) + assert sniffio.current_async_library() == label + try: + yield 1 + finally: + library = sniffio.current_async_library() + with contextlib.suppress(trio.Cancelled): + await sys.modules[library].sleep(0) + + del a + if sys.implementation.name == "pypy": + gc_collect_harder() + + record.add((label, library, a_wr() is None)) + + async def iterate_in_aio() -> None: + await agen("asyncio").asend(None) + + async def trio_main() -> None: + task = asyncio.ensure_future(iterate_in_aio()) + done_evt = trio.Event() + task.add_done_callback(lambda _: done_evt.set()) + with trio.fail_after(1): + await done_evt.wait() + + await agen("trio").asend(None) + + gc_collect_harder() + + aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True) + + assert record == {("asyncio", "asyncio", True), ("trio", "trio", True)} diff --git a/src/trio/_core/_tests/test_instrumentation.py b/src/trio/_core/_tests/test_instrumentation.py index 32335ae7fa..ff29ab3acb 100644 --- a/src/trio/_core/_tests/test_instrumentation.py +++ b/src/trio/_core/_tests/test_instrumentation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Container, Iterable, NoReturn +from typing import TYPE_CHECKING, NoReturn import attrs import pytest @@ -9,10 +9,12 @@ from .tutil import check_sequence_matches if TYPE_CHECKING: + from collections.abc import Container, Iterable + from ...lowlevel import Task -@attrs.define(eq=False, hash=False, slots=False) +@attrs.define(eq=False, slots=False) class TaskRecorder(_abc.Instrument): record: list[tuple[str, Task | None]] = attrs.Factory(list) @@ -204,7 +206,7 @@ async def main() -> Task: assert ("after_run", None) in r.record # And we got a log message assert caplog.records[0].exc_info is not None - exc_type, exc_value, exc_traceback = caplog.records[0].exc_info + exc_type, exc_value, _exc_traceback = caplog.records[0].exc_info assert exc_type is ValueError assert str(exc_value) == "oops" assert "Instrument has been disabled" in caplog.records[0].message diff --git a/src/trio/_core/_tests/test_io.py b/src/trio/_core/_tests/test_io.py index acecc9d6c6..379daa025e 100644 --- a/src/trio/_core/_tests/test_io.py +++ b/src/trio/_core/_tests/test_io.py @@ -1,9 +1,12 @@ from __future__ import annotations import random +import select import socket as stdlib_socket +import sys +from collections.abc import Awaitable, Callable from contextlib import suppress -from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar +from typing import TYPE_CHECKING, TypeVar import pytest @@ -39,7 +42,7 @@ def drain_socket(sock: stdlib_socket.socket) -> None: WaitSocket = Callable[[stdlib_socket.socket], Awaitable[object]] -SocketPair = Tuple[stdlib_socket.socket, stdlib_socket.socket] +SocketPair = tuple[stdlib_socket.socket, stdlib_socket.socket] RetT = TypeVar("RetT") @@ -91,7 +94,9 @@ def fileno_wrapper(fileobj: stdlib_socket.socket) -> RetT: @read_socket_test @write_socket_test async def test_wait_basic( - socketpair: SocketPair, wait_readable: WaitSocket, wait_writable: WaitSocket + socketpair: SocketPair, + wait_readable: WaitSocket, + wait_writable: WaitSocket, ) -> None: a, b = socketpair @@ -159,7 +164,7 @@ async def block_on_write() -> None: @read_socket_test async def test_double_read(socketpair: SocketPair, wait_readable: WaitSocket) -> None: - a, b = socketpair + a, _b = socketpair # You can't have two tasks trying to read from a socket at the same time async with _core.open_nursery() as nursery: @@ -172,7 +177,7 @@ async def test_double_read(socketpair: SocketPair, wait_readable: WaitSocket) -> @write_socket_test async def test_double_write(socketpair: SocketPair, wait_writable: WaitSocket) -> None: - a, b = socketpair + a, _b = socketpair # You can't have two tasks trying to write to a socket at the same time fill_socket(a) @@ -193,7 +198,7 @@ async def test_interrupted_by_close( wait_writable: WaitSocket, notify_closing: Callable[[stdlib_socket.socket], object], ) -> None: - a, b = socketpair + a, _b = socketpair async def reader() -> None: with pytest.raises(_core.ClosedResourceError): @@ -215,7 +220,9 @@ async def writer() -> None: @read_socket_test @write_socket_test async def test_socket_simultaneous_read_write( - socketpair: SocketPair, wait_readable: WaitSocket, wait_writable: WaitSocket + socketpair: SocketPair, + wait_readable: WaitSocket, + wait_writable: WaitSocket, ) -> None: record: list[str] = [] @@ -245,7 +252,9 @@ async def w_task(sock: stdlib_socket.socket) -> None: @read_socket_test @write_socket_test async def test_socket_actual_streaming( - socketpair: SocketPair, wait_readable: WaitSocket, wait_writable: WaitSocket + socketpair: SocketPair, + wait_readable: WaitSocket, + wait_writable: WaitSocket, ) -> None: a, b = socketpair @@ -336,6 +345,7 @@ def check(*, expected_readers: int, expected_writers: int) -> None: assert iostats.tasks_waiting_write == expected_writers else: assert iostats.backend == "kqueue" + assert iostats.monitors == 0 assert iostats.tasks_waiting == expected_readers + expected_writers a1, b1 = stdlib_socket.socketpair() @@ -374,6 +384,44 @@ def check(*, expected_readers: int, expected_writers: int) -> None: check(expected_readers=1, expected_writers=0) +@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning") +async def test_io_manager_kqueue_monitors_statistics() -> None: + def check( + *, + expected_monitors: int, + expected_readers: int, + expected_writers: int, + ) -> None: + statistics = _core.current_statistics() + print(statistics) + iostats = statistics.io_statistics + assert iostats.backend == "kqueue" + assert iostats.monitors == expected_monitors + assert iostats.tasks_waiting == expected_readers + expected_writers + + a1, b1 = stdlib_socket.socketpair() + for sock in [a1, b1]: + sock.setblocking(False) + + with a1, b1: + # let the call_soon_task settle down + await wait_all_tasks_blocked() + + if sys.platform != "win32" and sys.platform != "linux": + # 1 for call_soon_task + check(expected_monitors=0, expected_readers=1, expected_writers=0) + + with _core.monitor_kevent(a1.fileno(), select.KQ_FILTER_READ): + with ( + pytest.raises(_core.BusyResourceError), + _core.monitor_kevent(a1.fileno(), select.KQ_FILTER_READ), + ): + pass # pragma: no cover + check(expected_monitors=1, expected_readers=1, expected_writers=0) + + check(expected_monitors=0, expected_readers=1, expected_writers=0) + + async def test_can_survive_unnotified_close() -> None: # An "unnotified" close is when the user closes an fd/socket/handle # directly, without calling notify_closing first. This should never happen diff --git a/src/trio/_core/_tests/test_ki.py b/src/trio/_core/_tests/test_ki.py index e4241fc762..67c83e8358 100644 --- a/src/trio/_core/_tests/test_ki.py +++ b/src/trio/_core/_tests/test_ki.py @@ -3,14 +3,19 @@ import contextlib import inspect import signal +import sys import threading -from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator +import weakref +from collections.abc import AsyncIterator, Iterator +from typing import TYPE_CHECKING, Callable, TypeVar import outcome import pytest from trio.testing import RaisesGroup +from .tutil import gc_collect_harder + try: from async_generator import async_generator, yield_ except ImportError: # pragma: no cover @@ -18,16 +23,24 @@ from ... import _core from ..._abc import Instrument +from ..._core import _ki from ..._timeouts import sleep -from ..._util import signal_raise from ...testing import wait_all_tasks_blocked if TYPE_CHECKING: + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Callable, + Generator, + Iterator, + ) + from ..._core import Abort, RaiseCancelT def ki_self() -> None: - signal_raise(signal.SIGINT) + signal.raise_signal(signal.SIGINT) def test_ki_self() -> None: @@ -515,3 +528,179 @@ async def inner() -> None: _core.run(inner) finally: threading._active[thread.ident] = original # type: ignore[attr-defined] + + +_T = TypeVar("_T") + + +def _identity(v: _T) -> _T: + return v + + +@pytest.mark.xfail( + strict=True, + raises=AssertionError, + reason=( + "it was decided not to protect against this case, see discussion in: " + "https://github.com/python-trio/trio/pull/3110#discussion_r1802123644" + ), +) +async def test_ki_does_not_leak_across_different_calls_to_inner_functions() -> None: + assert not _core.currently_ki_protected() + + def factory(enabled: bool) -> Callable[[], bool]: + @_core.enable_ki_protection if enabled else _identity + def decorated() -> bool: + return _core.currently_ki_protected() + + return decorated + + decorated_enabled = factory(True) + decorated_disabled = factory(False) + assert decorated_enabled() + assert not decorated_disabled() + + +async def test_ki_protection_check_does_not_freeze_locals() -> None: + class A: + pass + + a = A() + wr_a = weakref.ref(a) + assert not _core.currently_ki_protected() + del a + if sys.implementation.name == "pypy": + gc_collect_harder() + assert wr_a() is None + + +def test_identity_weakref_internals() -> None: + """To cover the parts WeakKeyIdentityDictionary won't ever reach.""" + + class A: + def __eq__(self, other: object) -> bool: + return False + + a = A() + assert a != a + wr = _ki._IdRef(a) + wr_other_is_self = wr + + # dict always checks identity before equality so we need to do it here + # to cover `if self is other` + assert wr == wr_other_is_self + + # we want to cover __ne__ and `return NotImplemented` + assert wr != object() + + +def test_weak_key_identity_dict_remove_callback_keyerror() -> None: + """We need to cover the KeyError in self._remove.""" + + class A: + def __eq__(self, other: object) -> bool: + return False + + a = A() + assert a != a + d: _ki.WeakKeyIdentityDictionary[A, bool] = _ki.WeakKeyIdentityDictionary() + + d[a] = True + + data_copy = d._data.copy() + d._data.clear() + del a + + gc_collect_harder() # would call sys.unraisablehook if there's a problem + assert data_copy + + +def test_weak_key_identity_dict_remove_callback_selfref_expired() -> None: + """We need to cover the KeyError in self._remove.""" + + class A: + def __eq__(self, other: object) -> bool: + return False + + a = A() + assert a != a + d: _ki.WeakKeyIdentityDictionary[A, bool] = _ki.WeakKeyIdentityDictionary() + + d[a] = True + + data_copy = d._data.copy() + wr_d = weakref.ref(d) + del d + gc_collect_harder() # would call sys.unraisablehook if there's a problem + assert wr_d() is None + del a + gc_collect_harder() + assert data_copy + + +@_core.enable_ki_protection +async def _protected_async_gen_fn() -> AsyncGenerator[None, None]: + yield + + +@_core.enable_ki_protection +async def _protected_async_fn() -> None: + pass + + +@_core.enable_ki_protection +def _protected_gen_fn() -> Generator[None, None, None]: + yield + + +@_core.disable_ki_protection +async def _unprotected_async_gen_fn() -> AsyncGenerator[None, None]: + yield + + +@_core.disable_ki_protection +async def _unprotected_async_fn() -> None: + pass + + +@_core.disable_ki_protection +def _unprotected_gen_fn() -> Generator[None, None, None]: + yield + + +async def _consume_async_generator(agen: AsyncGenerator[None, None]) -> None: + try: + with pytest.raises(StopAsyncIteration): + while True: + await agen.asend(None) + finally: + await agen.aclose() + + +# Explicit .../"Any" is not allowed +def _consume_function_for_coverage( # type: ignore[misc] + fn: Callable[..., object], +) -> None: + result = fn() + if inspect.isasyncgen(result): + result = _consume_async_generator(result) + + assert inspect.isgenerator(result) or inspect.iscoroutine(result) + with pytest.raises(StopIteration): + while True: + result.send(None) + + +def test_enable_disable_ki_protection_passes_on_inspect_flags() -> None: + assert inspect.isasyncgenfunction(_protected_async_gen_fn) + _consume_function_for_coverage(_protected_async_gen_fn) + assert inspect.iscoroutinefunction(_protected_async_fn) + _consume_function_for_coverage(_protected_async_fn) + assert inspect.isgeneratorfunction(_protected_gen_fn) + _consume_function_for_coverage(_protected_gen_fn) + assert inspect.isasyncgenfunction(_unprotected_async_gen_fn) + _consume_function_for_coverage(_unprotected_async_gen_fn) + assert inspect.iscoroutinefunction(_unprotected_async_fn) + _consume_function_for_coverage(_unprotected_async_fn) + assert inspect.isgeneratorfunction(_unprotected_gen_fn) + _consume_function_for_coverage(_unprotected_gen_fn) diff --git a/src/trio/_core/_tests/test_mock_clock.py b/src/trio/_core/_tests/test_mock_clock.py index 6b0f1ca76b..b881c762b0 100644 --- a/src/trio/_core/_tests/test_mock_clock.py +++ b/src/trio/_core/_tests/test_mock_clock.py @@ -92,7 +92,8 @@ async def test_mock_clock_autojump(mock_clock: MockClock) -> None: mock_clock.autojump_threshold = 0 # if the above line didn't take affect immediately, then this would be # bad: - await sleep(100000) + # ignore ASYNC116, not sleep_forever, trying to test a large but finite sleep + await sleep(100000) # noqa: ASYNC116 async def test_mock_clock_autojump_interference(mock_clock: MockClock) -> None: @@ -109,7 +110,8 @@ async def test_mock_clock_autojump_interference(mock_clock: MockClock) -> None: await wait_all_tasks_blocked(0.015) # but the 0.02 limit does apply - await sleep(100000) + # ignore ASYNC116, not sleep_forever, trying to test a large but finite sleep + await sleep(100000) # noqa: ASYNC116 def test_mock_clock_autojump_preset() -> None: diff --git a/src/trio/_core/_tests/test_parking_lot.py b/src/trio/_core/_tests/test_parking_lot.py index 353c1ba45d..809fb2824a 100644 --- a/src/trio/_core/_tests/test_parking_lot.py +++ b/src/trio/_core/_tests/test_parking_lot.py @@ -1,9 +1,18 @@ from __future__ import annotations +import re from typing import TypeVar import pytest +import trio +from trio.lowlevel import ( + add_parking_lot_breaker, + current_task, + remove_parking_lot_breaker, +) +from trio.testing import Matcher, RaisesGroup + from ... import _core from ...testing import wait_all_tasks_blocked from .._parking_lot import ParkingLot @@ -38,7 +47,8 @@ async def waiter(i: int, lot: ParkingLot) -> None: assert len(record) == 6 check_sequence_matches( - record, [{"sleep 0", "sleep 1", "sleep 2"}, {"wake 0", "wake 1", "wake 2"}] + record, + [{"sleep 0", "sleep 1", "sleep 2"}, {"wake 0", "wake 1", "wake 2"}], ) async with _core.open_nursery() as nursery: @@ -74,18 +84,23 @@ async def waiter(i: int, lot: ParkingLot) -> None: lot.unpark(count=2) await wait_all_tasks_blocked() check_sequence_matches( - record, ["sleep 0", "sleep 1", "sleep 2", {"wake 0", "wake 1"}] + record, + ["sleep 0", "sleep 1", "sleep 2", {"wake 0", "wake 1"}], ) lot.unpark_all() with pytest.raises( - ValueError, match=r"^Cannot pop a non-integer number of tasks\.$" + ValueError, + match=r"^Cannot pop a non-integer number of tasks\.$", ): lot.unpark(count=1.5) async def cancellable_waiter( - name: T, lot: ParkingLot, scopes: dict[T, _core.CancelScope], record: list[str] + name: T, + lot: ParkingLot, + scopes: dict[T, _core.CancelScope], + record: list[str], ) -> None: with _core.CancelScope() as scope: scopes[name] = scope @@ -120,7 +135,8 @@ async def test_parking_lot_cancel() -> None: assert len(record) == 6 check_sequence_matches( - record, ["sleep 1", "sleep 2", "sleep 3", "cancelled 2", {"wake 1", "wake 3"}] + record, + ["sleep 1", "sleep 2", "sleep 3", "cancelled 2", {"wake 1", "wake 3"}], ) @@ -208,3 +224,167 @@ async def test_parking_lot_repark_with_count() -> None: "wake 2", ] lot1.unpark_all() + + +async def dummy_task( + task_status: _core.TaskStatus[_core.Task] = trio.TASK_STATUS_IGNORED, +) -> None: + task_status.started(_core.current_task()) + await trio.sleep_forever() + + +async def test_parking_lot_breaker_basic() -> None: + """Test basic functionality for breaking lots.""" + lot = ParkingLot() + task = current_task() + + # defaults to current task + lot.break_lot() + assert lot.broken_by == [task] + + # breaking the lot again with the same task appends another copy in `broken_by` + lot.break_lot() + assert lot.broken_by == [task, task] + + # trying to park in broken lot errors + broken_by_str = re.escape(str([task, task])) + with pytest.raises( + _core.BrokenResourceError, + match=f"^Attempted to park in parking lot broken by {broken_by_str}$", + ): + await lot.park() + + +async def test_parking_lot_break_parking_tasks() -> None: + """Checks that tasks currently waiting to park raise an error when the breaker exits.""" + + async def bad_parker(lot: ParkingLot, scope: _core.CancelScope) -> None: + add_parking_lot_breaker(current_task(), lot) + with scope: + await trio.sleep_forever() + + lot = ParkingLot() + cs = _core.CancelScope() + + # check that parked task errors + with RaisesGroup( + Matcher(_core.BrokenResourceError, match="^Parking lot broken by"), + ): + async with _core.open_nursery() as nursery: + nursery.start_soon(bad_parker, lot, cs) + await wait_all_tasks_blocked() + + nursery.start_soon(lot.park) + await wait_all_tasks_blocked() + + cs.cancel() + + +async def test_parking_lot_breaker_registration() -> None: + lot = ParkingLot() + task = current_task() + + with pytest.raises( + RuntimeError, + match="Attempted to remove task as breaker for a lot it is not registered for", + ): + remove_parking_lot_breaker(task, lot) + + # check that a task can be registered as breaker for the same lot multiple times + add_parking_lot_breaker(task, lot) + add_parking_lot_breaker(task, lot) + remove_parking_lot_breaker(task, lot) + remove_parking_lot_breaker(task, lot) + + with pytest.raises( + RuntimeError, + match="Attempted to remove task as breaker for a lot it is not registered for", + ): + remove_parking_lot_breaker(task, lot) + + # registering a task as breaker on an already broken lot is fine + lot.break_lot() + child_task: _core.Task | None = None + async with trio.open_nursery() as nursery: + child_task = await nursery.start(dummy_task) + assert isinstance(child_task, _core.Task) + add_parking_lot_breaker(child_task, lot) + nursery.cancel_scope.cancel() + assert lot.broken_by == [task, child_task] + + # manually breaking a lot with an already exited task is fine + lot = ParkingLot() + lot.break_lot(child_task) + assert lot.broken_by == [child_task] + + +async def test_parking_lot_breaker_rebreak() -> None: + lot = ParkingLot() + task = current_task() + lot.break_lot() + + # breaking an already broken lot with a different task is allowed + # The nursery is only to create a task we can pass to lot.break_lot + async with trio.open_nursery() as nursery: + child_task = await nursery.start(dummy_task) + lot.break_lot(child_task) + nursery.cancel_scope.cancel() + + assert lot.broken_by == [task, child_task] + + +async def test_parking_lot_multiple_breakers_exit() -> None: + # register multiple tasks as lot breakers, then have them all exit + lot = ParkingLot() + async with trio.open_nursery() as nursery: + child_task1 = await nursery.start(dummy_task) + child_task2 = await nursery.start(dummy_task) + child_task3 = await nursery.start(dummy_task) + assert isinstance(child_task1, _core.Task) + assert isinstance(child_task2, _core.Task) + assert isinstance(child_task3, _core.Task) + add_parking_lot_breaker(child_task1, lot) + add_parking_lot_breaker(child_task2, lot) + add_parking_lot_breaker(child_task3, lot) + nursery.cancel_scope.cancel() + + # I think the order is guaranteed currently, but doesn't hurt to be safe. + assert set(lot.broken_by) == {child_task1, child_task2, child_task3} + + +async def test_parking_lot_breaker_register_exited_task() -> None: + lot = ParkingLot() + child_task: _core.Task | None = None + async with trio.open_nursery() as nursery: + value = await nursery.start(dummy_task) + assert isinstance(value, _core.Task) + child_task = value + nursery.cancel_scope.cancel() + # trying to register an exited task as lot breaker errors + with pytest.raises( + trio.BrokenResourceError, + match="^Attempted to add already exited task as lot breaker.$", + ): + add_parking_lot_breaker(child_task, lot) + + +async def test_parking_lot_break_itself() -> None: + """Break a parking lot, where the breakee is parked. + Doing this is weird, but should probably be supported. + """ + + async def return_me_and_park( + lot: ParkingLot, + *, + task_status: _core.TaskStatus[_core.Task] = trio.TASK_STATUS_IGNORED, + ) -> None: + task_status.started(_core.current_task()) + await lot.park() + + lot = ParkingLot() + with RaisesGroup( + Matcher(_core.BrokenResourceError, match="^Parking lot broken by"), + ): + async with _core.open_nursery() as nursery: + child_task = await nursery.start(return_me_and_park, lot) + lot.break_lot(child_task) diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index ee823cb81a..0d1cf46722 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -9,8 +9,9 @@ import types import weakref from contextlib import ExitStack, contextmanager, suppress -from math import inf -from typing import TYPE_CHECKING, Any, NoReturn, TypeVar, cast +from math import inf, nan +from typing import TYPE_CHECKING, NoReturn, TypeVar +from unittest import mock import outcome import pytest @@ -26,7 +27,7 @@ assert_checkpoints, wait_all_tasks_blocked, ) -from .._run import DEADLINE_HEAP_MIN_PRUNE_THRESHOLD +from .._run import DEADLINE_HEAP_MIN_PRUNE_THRESHOLD, _count_context_run_tb_frames from .tutil import ( check_sequence_matches, create_asyncio_future_in_new_loop, @@ -116,7 +117,7 @@ async def test_nursery_warn_use_async_with() -> None: with on: # type: ignore pass # pragma: no cover excinfo.match( - r"use 'async with open_nursery\(...\)', not 'with open_nursery\(...\)'" + r"use 'async with open_nursery\(...\)', not 'with open_nursery\(...\)'", ) # avoid unawaited coro. @@ -156,7 +157,8 @@ async def looper(whoami: str, record: list[tuple[str, int]]) -> None: nursery.start_soon(looper, "b", record) check_sequence_matches( - record, [{("a", 0), ("b", 0)}, {("a", 1), ("b", 1)}, {("a", 2), ("b", 2)}] + record, + [{("a", 0), ("b", 0)}, {("a", 1), ("b", 1)}, {("a", 2), ("b", 2)}], ) @@ -268,8 +270,8 @@ async def test_current_time_with_mock_clock(mock_clock: _core.MockClock) -> None start = mock_clock.current_time() assert mock_clock.current_time() == _core.current_time() assert mock_clock.current_time() == _core.current_time() - mock_clock.jump(3.14) - assert start + 3.14 == mock_clock.current_time() == _core.current_time() + mock_clock.jump(3.15) + assert start + 3.15 == mock_clock.current_time() == _core.current_time() async def test_current_clock(mock_clock: _core.MockClock) -> None: @@ -364,6 +366,36 @@ async def test_cancel_scope_repr(mock_clock: _core.MockClock) -> None: assert "exited" in repr(scope) +async def test_cancel_scope_validation() -> None: + with pytest.raises( + ValueError, + match="^Cannot specify both a deadline and a relative deadline$", + ): + _core.CancelScope(deadline=7, relative_deadline=3) + + with pytest.raises(ValueError, match="^deadline must not be NaN$"): + _core.CancelScope(deadline=nan) + with pytest.raises(ValueError, match="^relative deadline must not be NaN$"): + _core.CancelScope(relative_deadline=nan) + + with pytest.raises(ValueError, match="^timeout must be non-negative$"): + _core.CancelScope(relative_deadline=-3) + + scope = _core.CancelScope() + + with pytest.raises(ValueError, match="^deadline must not be NaN$"): + scope.deadline = nan + with pytest.raises(ValueError, match="^relative deadline must not be NaN$"): + scope.relative_deadline = nan + + with pytest.raises(ValueError, match="^relative deadline must be non-negative$"): + scope.relative_deadline = -3 + scope.relative_deadline = 5 + assert scope.relative_deadline == 5 + + # several related tests of CancelScope are implicitly handled by test_timeouts.py + + def test_cancel_points() -> None: async def main1() -> None: with _core.CancelScope() as scope: @@ -434,7 +466,10 @@ async def crasher() -> NoReturn: # nursery block continue propagating to reach the # outer scope. with RaisesGroup( - _core.Cancelled, _core.Cancelled, _core.Cancelled, KeyError + _core.Cancelled, + _core.Cancelled, + _core.Cancelled, + KeyError, ) as excinfo: async with _core.open_nursery() as nursery: # Two children that get cancelled by the nursery scope @@ -769,7 +804,8 @@ async def task2() -> None: nursery.cancel_scope.__exit__(None, None, None) finally: with pytest.raises( - RuntimeError, match="which had already been exited" + RuntimeError, + match="which had already been exited", ) as exc_info: await nursery_mgr.__aexit__(*sys.exc_info()) @@ -797,7 +833,9 @@ async def task3(task_status: _core.TaskStatus[_core.CancelScope]) -> None: await sleep_forever() async with _core.open_nursery() as nursery: - scope: _core.CancelScope = await nursery.start(task3) + value = await nursery.start(task3) + assert isinstance(value, _core.CancelScope) + scope: _core.CancelScope = value with pytest.raises(RuntimeError, match="from unrelated"): scope.__exit__(None, None, None) scope.cancel() @@ -946,7 +984,7 @@ async def main() -> None: # the second exceptiongroup is from the second nursery opened in Runner.init() # the third exceptongroup is from the nursery defined in `system_task` above assert RaisesGroup(RaisesGroup(RaisesGroup(KeyError, ValueError))).matches( - excinfo.value.__cause__ + excinfo.value.__cause__, ) @@ -976,7 +1014,7 @@ async def main() -> None: # See explanation for triple-wrap in test_system_task_crash_ExceptionGroup assert RaisesGroup(RaisesGroup(RaisesGroup(ValueError))).matches( - excinfo.value.__cause__ + excinfo.value.__cause__, ) @@ -1126,13 +1164,18 @@ async def child() -> None: await sleep_forever() with RaisesGroup( - Matcher(ValueError, "error text", lambda e: isinstance(e.__context__, KeyError)) + Matcher( + ValueError, + "error text", + lambda e: isinstance(e.__context__, KeyError), + ), ): async with _core.open_nursery() as nursery: nursery.start_soon(child) await wait_all_tasks_blocked() _core.reschedule( - not_none(child_task), outcome.Error(ValueError("error text")) + not_none(child_task), + outcome.Error(ValueError("error text")), ) @@ -1195,13 +1238,14 @@ async def inner() -> None: "^Unique Text$", lambda e: isinstance(e.__context__, IndexError) and isinstance(e.__context__.__context__, KeyError), - ) + ), ): async with _core.open_nursery() as nursery: nursery.start_soon(child) await wait_all_tasks_blocked() _core.reschedule( - not_none(child_task), outcome.Error(ValueError("Unique Text")) + not_none(child_task), + outcome.Error(ValueError("Unique Text")), ) @@ -1614,7 +1658,10 @@ async def func1(expected: str) -> None: async def func2() -> None: # pragma: no cover pass - async def check(spawn_fn: Callable[..., object]) -> None: + # Explicit .../"Any" is not allowed + async def check( # type: ignore[misc] + spawn_fn: Callable[..., object], + ) -> None: spawn_fn(func1, "func1") spawn_fn(func1, "func2", name=func2) spawn_fn(func1, "func3", name="func3") @@ -1649,13 +1696,14 @@ async def test_current_effective_deadline(mock_clock: _core.MockClock) -> None: def test_nice_error_on_bad_calls_to_run_or_spawn() -> None: - def bad_call_run( + # Explicit .../"Any" is not allowed + def bad_call_run( # type: ignore[misc] func: Callable[..., Awaitable[object]], *args: tuple[object, ...], ) -> None: _core.run(func, *args) - def bad_call_spawn( + def bad_call_spawn( # type: ignore[misc] func: Callable[..., Awaitable[object]], *args: tuple[object, ...], ) -> None: @@ -1680,7 +1728,8 @@ async def async_gen(arg: T) -> AsyncGenerator[T, None]: # pragma: no cover ): bad_call_run(f()) # type: ignore[arg-type] with pytest.raises( - TypeError, match="expected an async function but got an async generator" + TypeError, + match="expected an async function but got an async generator", ): bad_call_run(async_gen, 0) # type: ignore @@ -1690,7 +1739,7 @@ async def async_gen(arg: T) -> AsyncGenerator[T, None]: # pragma: no cover bad_call_spawn(f()) # type: ignore[arg-type] with RaisesGroup( - Matcher(TypeError, "expected an async function but got an async generator") + Matcher(TypeError, "expected an async function but got an async generator"), ): bad_call_spawn(async_gen, 0) # type: ignore @@ -1768,7 +1817,9 @@ async def no_args() -> None: # pragma: no cover await nursery.start(no_args) async def sleep_then_start( - seconds: int, *, task_status: _core.TaskStatus[int] = _core.TASK_STATUS_IGNORED + seconds: int, + *, + task_status: _core.TaskStatus[int] = _core.TASK_STATUS_IGNORED, ) -> None: repr(task_status) # smoke test await sleep(seconds) @@ -1847,7 +1898,8 @@ async def just_started( # but if the task does not execute any checkpoints, and exits, then start() # doesn't raise Cancelled, since the task completed successfully. async def started_with_no_checkpoint( - *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED + *, + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: task_status.started(None) @@ -1861,7 +1913,8 @@ async def started_with_no_checkpoint( # the child crashes after calling started(), the error can *still* come # out of start() async def raise_keyerror_after_started( - *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED + *, + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: task_status.started() raise KeyError("whoopsiedaisy") @@ -1922,14 +1975,19 @@ async def sleeping_children( # Cancelling the setup_nursery just *before* calling started() async with _core.open_nursery() as nursery: - target_nursery: _core.Nursery = await nursery.start(setup_nursery) + value = await nursery.start(setup_nursery) + assert isinstance(value, _core.Nursery) + target_nursery: _core.Nursery = value await target_nursery.start( - sleeping_children, target_nursery.cancel_scope.cancel + sleeping_children, + target_nursery.cancel_scope.cancel, ) # Cancelling the setup_nursery just *after* calling started() async with _core.open_nursery() as nursery: - target_nursery = await nursery.start(setup_nursery) + value = await nursery.start(setup_nursery) + assert isinstance(value, _core.Nursery) + target_nursery = value await target_nursery.start(sleeping_children, lambda: None) target_nursery.cancel_scope.cancel() @@ -2014,7 +2072,10 @@ def __init__(self, *largs: it) -> None: self.nexts = [obj.__anext__ for obj in largs] async def _accumulate( - self, f: Callable[[], Awaitable[int]], items: list[int], i: int + self, + f: Callable[[], Awaitable[int]], + items: list[int], + i: int, ) -> None: items[i] = await f() @@ -2036,7 +2097,8 @@ async def __anext__(self) -> list[int]: # We could also use RaisesGroup, but that's primarily meant as # test infra, not as a runtime tool. if len(e.exceptions) == 1 and isinstance( - e.exceptions[0], StopAsyncIteration + e.exceptions[0], + StopAsyncIteration, ): raise e.exceptions[0] from None else: # pragma: no cover @@ -2243,8 +2305,10 @@ async def detachable_coroutine( await sleep(0) nonlocal task, pdco_outcome task = _core.current_task() - pdco_outcome = await outcome.acapture( - _core.permanently_detach_coroutine_object, task_outcome + # `No overload variant of "acapture" matches argument types "Callable[[Outcome[object]], Coroutine[Any, Any, object]]", "Outcome[None]"` + pdco_outcome = await outcome.acapture( # type: ignore[call-overload] + _core.permanently_detach_coroutine_object, + task_outcome, ) await async_yield(yield_value) @@ -2255,10 +2319,11 @@ async def detachable_coroutine( # is still iterable. At that point anything can be sent into the coroutine, so the .coro type # is wrong. assert pdco_outcome is None - assert not_none(task).coro.send(cast(Any, "be free!")) == "I'm free!" + # `Argument 1 to "send" of "Coroutine" has incompatible type "str"; expected "Outcome[object]"` + assert not_none(task).coro.send("be free!") == "I'm free!" # type: ignore[arg-type] assert pdco_outcome == outcome.Value("be free!") with pytest.raises(StopIteration): - not_none(task).coro.send(cast(Any, None)) + not_none(task).coro.send(None) # type: ignore[arg-type] # Check the exception paths too task = None @@ -2271,7 +2336,7 @@ async def detachable_coroutine( assert not_none(task).coro.throw(throw_in) == "uh oh" assert pdco_outcome == outcome.Error(throw_in) with pytest.raises(StopIteration): - task.coro.send(cast(Any, None)) + task.coro.send(None) async def bad_detach() -> None: async with _core.open_nursery(): @@ -2308,7 +2373,8 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: # pragma: no cover with pytest.raises(RuntimeError) as excinfo: await _core.reattach_detached_coroutine_object( - not_none(unrelated_task), None + not_none(unrelated_task), + None, ) assert "does not match" in str(excinfo.value) @@ -2322,9 +2388,10 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: # pragma: no cover await wait_all_tasks_blocked() # Okay, it's detached. Here's our coroutine runner: - assert not_none(task).coro.send(cast(Any, "not trio!")) == 1 - assert not_none(task).coro.send(cast(Any, None)) == 2 - assert not_none(task).coro.send(cast(Any, None)) == "byebye" + # `Argument 1 to "send" of "Coroutine" has incompatible type "str"; expected "Outcome[object]"` + assert not_none(task).coro.send("not trio!") == 1 # type: ignore[arg-type] + assert not_none(task).coro.send(None) == 2 # type: ignore[arg-type] + assert not_none(task).coro.send(None) == "byebye" # type: ignore[arg-type] # Now it's been reattached, and we can leave the nursery @@ -2354,7 +2421,8 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: await wait_all_tasks_blocked() assert task is not None nursery.cancel_scope.cancel() - task.coro.send(cast(Any, None)) + # `Argument 1 to "send" of "Coroutine" has incompatible type "None"; expected "Outcome[object]"` + task.coro.send(None) # type: ignore[arg-type] assert abort_fn_called @@ -2409,7 +2477,8 @@ async def test_cancel_scope_deadline_duplicates() -> None: # refer to this only seems to break test_cancel_scope_exit_doesnt_create_cyclic_garbage # We're keeping it for now to cover Outcome and potential future refactoring @pytest.mark.skipif( - sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" + sys.implementation.name != "cpython", + reason="Only makes sense with refcounting GC", ) async def test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/issues/1770 @@ -2448,7 +2517,8 @@ async def crasher() -> NoReturn: @pytest.mark.skipif( - sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" + sys.implementation.name != "cpython", + reason="Only makes sense with refcounting GC", ) async def test_cancel_scope_exit_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/pull/2063 @@ -2459,9 +2529,13 @@ async def crasher() -> NoReturn: old_flags = gc.get_debug() try: + # fmt: off + # Remove after 3.9 unsupported, black formats in a way that breaks if + # you do `-X oldparser` with RaisesGroup( - Matcher(ValueError, "^this is a crash$") + Matcher(ValueError, "^this is a crash$"), ), _core.CancelScope() as outer: + # fmt: on async with _core.open_nursery() as nursery: gc.collect() gc.set_debug(gc.DEBUG_SAVEALL) @@ -2482,7 +2556,8 @@ async def crasher() -> NoReturn: @pytest.mark.skipif( - sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" + sys.implementation.name != "cpython", + reason="Only makes sense with refcounting GC", ) async def test_nursery_cancel_doesnt_create_cyclic_garbage() -> None: collected = False @@ -2518,7 +2593,8 @@ def toggle_collected() -> None: @pytest.mark.skipif( - sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" + sys.implementation.name != "cpython", + reason="Only makes sense with refcounting GC", ) async def test_locals_destroyed_promptly_on_cancel() -> None: destroyed = False @@ -2550,13 +2626,15 @@ def _create_kwargs(strictness: bool | None) -> dict[str, bool]: @pytest.mark.filterwarnings( - "ignore:.*strict_exception_groups=False:trio.TrioDeprecationWarning" + "ignore:.*strict_exception_groups=False:trio.TrioDeprecationWarning", ) @pytest.mark.parametrize("run_strict", [True, False, None]) @pytest.mark.parametrize("open_nursery_strict", [True, False, None]) @pytest.mark.parametrize("multiple_exceptions", [True, False]) def test_setting_strict_exception_groups( - run_strict: bool | None, open_nursery_strict: bool | None, multiple_exceptions: bool + run_strict: bool | None, + open_nursery_strict: bool | None, + multiple_exceptions: bool, ) -> None: """ Test default values and that nurseries can both inherit and override the global context @@ -2593,7 +2671,7 @@ def run_main() -> None: @pytest.mark.filterwarnings( - "ignore:.*strict_exception_groups=False:trio.TrioDeprecationWarning" + "ignore:.*strict_exception_groups=False:trio.TrioDeprecationWarning", ) @pytest.mark.parametrize("strict", [True, False, None]) async def test_nursery_collapse(strict: bool | None) -> None: @@ -2635,7 +2713,7 @@ async def test_cancel_scope_no_cancellederror() -> None: @pytest.mark.filterwarnings( - "ignore:.*strict_exception_groups=False:trio.TrioDeprecationWarning" + "ignore:.*strict_exception_groups=False:trio.TrioDeprecationWarning", ) @pytest.mark.parametrize("run_strict", [False, True]) @pytest.mark.parametrize("start_raiser_strict", [False, True, None]) @@ -2670,7 +2748,7 @@ async def raiser(*, task_status: _core.TaskStatus[None]) -> None: async def start_raiser() -> None: try: async with _core.open_nursery( - strict_exception_groups=start_raiser_strict + strict_exception_groups=start_raiser_strict, ) as nursery: await nursery.start(raiser) except BaseExceptionGroup as exc_group: @@ -2680,7 +2758,8 @@ async def start_raiser() -> None: # exception group raised by trio with a more specific one (subtype, # different message, etc.). raise BaseExceptionGroup( - "start_raiser nursery custom message", exc_group.exceptions + "start_raiser nursery custom message", + exc_group.exceptions, ) from None raise @@ -2723,3 +2802,56 @@ async def spawn_tasks_in_old_nursery(task_status: _core.TaskStatus[None]) -> Non with pytest.raises(_core.TrioInternalError) as excinfo: await nursery.start(spawn_tasks_in_old_nursery) assert RaisesGroup(ValueError, ValueError).matches(excinfo.value.__cause__) + + +if sys.version_info >= (3, 11): + + def no_other_refs() -> list[object]: + return [] + +else: + + def no_other_refs() -> list[object]: + return [sys._getframe(1)] + + +@pytest.mark.skipif( + sys.implementation.name != "cpython", + reason="Only makes sense with refcounting GC", +) +async def test_ki_protection_doesnt_leave_cyclic_garbage() -> None: + class MyException(Exception): + pass + + async def demo() -> None: + async def handle_error() -> None: + try: + raise MyException + except MyException as e: + exceptions.append(e) + + exceptions: list[MyException] = [] + try: + async with _core.open_nursery() as n: + n.start_soon(handle_error) + raise ExceptionGroup("errors", exceptions) + finally: + exceptions = [] + + exc: Exception | None = None + try: + await demo() + except ExceptionGroup as excs: + exc = excs.exceptions[0] + + assert isinstance(exc, MyException) + assert gc.get_referrers(exc) == no_other_refs() + + +def test_context_run_tb_frames() -> None: + class Context: + def run(self, fn: Callable[[], object]) -> object: + return fn() + + with mock.patch("trio._core._run.copy_context", return_value=Context()): + assert _count_context_run_tb_frames() == 1 diff --git a/src/trio/_core/_tests/test_thread_cache.py b/src/trio/_core/_tests/test_thread_cache.py index ee301d17fd..1e3841ee0d 100644 --- a/src/trio/_core/_tests/test_thread_cache.py +++ b/src/trio/_core/_tests/test_thread_cache.py @@ -4,7 +4,7 @@ import time from contextlib import contextmanager from queue import Queue -from typing import TYPE_CHECKING, Iterator, NoReturn +from typing import TYPE_CHECKING, NoReturn import pytest @@ -13,6 +13,8 @@ from .tutil import gc_collect_harder, slow if TYPE_CHECKING: + from collections.abc import Iterator + from outcome import Outcome diff --git a/src/trio/_core/_tests/test_windows.py b/src/trio/_core/_tests/test_windows.py index c65f0a863d..e548326935 100644 --- a/src/trio/_core/_tests/test_windows.py +++ b/src/trio/_core/_tests/test_windows.py @@ -54,7 +54,8 @@ def test_winerror(monkeypatch: pytest.MonkeyPatch) -> None: mock.return_value = (12, "test error") with pytest.raises( - OSError, match=r"^\[WinError 12\] test error: 'file_1' -> 'file_2'$" + OSError, + match=r"^\[WinError 12\] test error: 'file_1' -> 'file_2'$", ) as exc: raise_winerror(filename="file_1", filename2="file_2") mock.assert_called_once_with() @@ -66,7 +67,8 @@ def test_winerror(monkeypatch: pytest.MonkeyPatch) -> None: # With an explicit number passed in, it overrides what getwinerror() returns. with pytest.raises( - OSError, match=r"^\[WinError 18\] test error: 'a/file' -> 'b/file'$" + OSError, + match=r"^\[WinError 18\] test error: 'a/file' -> 'b/file'$", ) as exc: raise_winerror(18, filename="a/file", filename2="b/file") mock.assert_called_once_with(18) @@ -109,7 +111,8 @@ async def test_readinto_overlapped() -> None: with tempfile.TemporaryDirectory() as tdir: tfile = os.path.join(tdir, "numbers.txt") with open( # noqa: ASYNC230 # This is a test, synchronous is ok - tfile, "wb" + tfile, + "wb", ) as fp: fp.write(data) fp.flush() @@ -133,7 +136,9 @@ async def test_readinto_overlapped() -> None: async def read_region(start: int, end: int) -> None: await _core.readinto_overlapped( - handle, buffer_view[start:end], start + handle, + buffer_view[start:end], + start, ) _core.register_with_iocp(handle) @@ -176,7 +181,10 @@ async def main() -> None: try: async with _core.open_nursery() as nursery: nursery.start_soon( - _core.readinto_overlapped, read_handle, target, name="xyz" + _core.readinto_overlapped, + read_handle, + target, + name="xyz", ) await wait_all_tasks_blocked() nursery.cancel_scope.cancel() @@ -233,7 +241,9 @@ def test_lsp_that_hooks_select_gives_good_error( from .._windows_cffi import CData, WSAIoctls, _handle def patched_get_underlying( - sock: int | CData, *, which: int = WSAIoctls.SIO_BASE_HANDLE + sock: int | CData, + *, + which: int = WSAIoctls.SIO_BASE_HANDLE, ) -> CData: if hasattr(sock, "fileno"): # pragma: no branch sock = sock.fileno() @@ -244,7 +254,8 @@ def patched_get_underlying( monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying) with pytest.raises( - RuntimeError, match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ" + RuntimeError, + match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ", ): _core.run(sleep, 0) @@ -261,7 +272,9 @@ def test_lsp_that_completely_hides_base_socket_gives_good_error( from .._windows_cffi import CData, WSAIoctls, _handle def patched_get_underlying( - sock: int | CData, *, which: int = WSAIoctls.SIO_BASE_HANDLE + sock: int | CData, + *, + which: int = WSAIoctls.SIO_BASE_HANDLE, ) -> CData: if hasattr(sock, "fileno"): # pragma: no branch sock = sock.fileno() diff --git a/src/trio/_core/_tests/tutil.py b/src/trio/_core/_tests/tutil.py index 81370ed76e..063fa1dd80 100644 --- a/src/trio/_core/_tests/tutil.py +++ b/src/trio/_core/_tests/tutil.py @@ -12,7 +12,7 @@ import pytest -# See trio/_tests/conftest.py for the other half of this +# See trio/_tests/pytest_plugin.py for the other half of this from trio._tests.pytest_plugin import RUN_SLOW if TYPE_CHECKING: diff --git a/src/trio/_core/_tests/type_tests/nursery_start.py b/src/trio/_core/_tests/type_tests/nursery_start.py index 77667590b9..4ce03b2721 100644 --- a/src/trio/_core/_tests/type_tests/nursery_start.py +++ b/src/trio/_core/_tests/type_tests/nursery_start.py @@ -1,9 +1,12 @@ """Test variadic generic typing for Nursery.start[_soon]().""" -from typing import Awaitable, Callable +from typing import TYPE_CHECKING from trio import TASK_STATUS_IGNORED, Nursery, TaskStatus +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + async def task_0() -> None: ... @@ -47,7 +50,6 @@ async def task_requires_start(*, task_status: TaskStatus[str]) -> None: async def task_pos_or_kw(value: str, task_status: TaskStatus[int]) -> None: """Check a function which doesn't use the *-syntax works.""" - ... def check_start_soon(nursery: Nursery) -> None: diff --git a/src/trio/_core/_tests/type_tests/run.py b/src/trio/_core/_tests/type_tests/run.py index c121ce6c7a..5c51b91496 100644 --- a/src/trio/_core/_tests/type_tests/run.py +++ b/src/trio/_core/_tests/type_tests/run.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import Sequence, overload +from typing import TYPE_CHECKING, overload import trio from typing_extensions import assert_type +if TYPE_CHECKING: + from collections.abc import Sequence + async def sleep_sort(values: Sequence[float]) -> list[float]: return [1] @@ -29,7 +32,9 @@ async def foo_overloaded(arg: int | str) -> int | str: v = trio.run( - sleep_sort, (1, 3, 5, 2, 4), clock=trio.testing.MockClock(autojump_threshold=0) + sleep_sort, + (1, 3, 5, 2, 4), + clock=trio.testing.MockClock(autojump_threshold=0), ) assert_type(v, "list[float]") trio.run(sleep_sort, ["hi", "there"]) # type: ignore[arg-type] diff --git a/src/trio/_core/_thread_cache.py b/src/trio/_core/_thread_cache.py index d338ec1ee7..189d5a5836 100644 --- a/src/trio/_core/_thread_cache.py +++ b/src/trio/_core/_thread_cache.py @@ -7,10 +7,13 @@ from functools import partial from itertools import count from threading import Lock, Thread -from typing import Any, Callable, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar import outcome +if TYPE_CHECKING: + from collections.abc import Callable + RetT = TypeVar("RetT") @@ -23,7 +26,9 @@ def _to_os_thread_name(name: str) -> bytes: # called once on import def get_os_thread_name_func() -> Callable[[int | None, str], None] | None: def namefunc( - setname: Callable[[int, bytes], int], ident: int | None, name: str + setname: Callable[[int, bytes], int], + ident: int | None, + name: str, ) -> None: # Thread.ident is None "if it has not been started". Unclear if that can happen # with current usage. @@ -33,7 +38,9 @@ def namefunc( # namefunc on Mac also takes an ident, even if pthread_setname_np doesn't/can't use it # so the caller don't need to care about platform. def darwin_namefunc( - setname: Callable[[bytes], int], ident: int | None, name: str + setname: Callable[[bytes], int], + ident: int | None, + name: str, ) -> None: # I don't know if Mac can rename threads that hasn't been started, but default # to no to be on the safe side. @@ -122,6 +129,8 @@ def darwin_namefunc( class WorkerThread(Generic[RetT]): + __slots__ = ("_default_name", "_job", "_thread", "_thread_cache", "_worker_lock") + def __init__(self, thread_cache: ThreadCache) -> None: self._job: ( tuple[ @@ -203,8 +212,11 @@ def _work(self) -> None: class ThreadCache: + __slots__ = ("_idle_workers",) + def __init__(self) -> None: - self._idle_workers: dict[WorkerThread[Any], None] = {} + # Explicit "Any" not allowed + self._idle_workers: dict[WorkerThread[Any], None] = {} # type: ignore[misc] def start_thread_soon( self, diff --git a/src/trio/_core/_traps.py b/src/trio/_core/_traps.py index 85e6b57306..1ddd5628ba 100644 --- a/src/trio/_core/_traps.py +++ b/src/trio/_core/_traps.py @@ -4,7 +4,9 @@ import enum import types -from typing import TYPE_CHECKING, Any, Callable, NoReturn + +# Jedi gets mad in test_static_tool_sees_class_members if we use collections Callable +from typing import TYPE_CHECKING, Any, Callable, NoReturn, Union, cast import attrs import outcome @@ -12,10 +14,40 @@ from . import _run if TYPE_CHECKING: + from collections.abc import Awaitable, Generator + from typing_extensions import TypeAlias from ._run import Task +RaiseCancelT: TypeAlias = Callable[[], NoReturn] + + +# This class object is used as a singleton. +# Not exported in the trio._core namespace, but imported directly by _run. +class CancelShieldedCheckpoint: + __slots__ = () + + +# Not exported in the trio._core namespace, but imported directly by _run. +@attrs.frozen(slots=False) +class WaitTaskRescheduled: + abort_func: Callable[[RaiseCancelT], Abort] + + +# Not exported in the trio._core namespace, but imported directly by _run. +@attrs.frozen(slots=False) +class PermanentlyDetachCoroutineObject: + final_outcome: outcome.Outcome[object] + + +MessageType: TypeAlias = Union[ + type[CancelShieldedCheckpoint], + WaitTaskRescheduled, + PermanentlyDetachCoroutineObject, + object, +] + # Helper for the bottommost 'yield'. You can't use 'yield' inside an async # function, but you can inside a generator, and if you decorate your generator @@ -25,14 +57,18 @@ # tracking machinery. Since our traps are public APIs, we make them real async # functions, and then this helper takes care of the actual yield: @types.coroutine -def _async_yield(obj: Any) -> Any: # type: ignore[misc] +def _real_async_yield( + obj: MessageType, +) -> Generator[MessageType, None, None]: return (yield obj) -# This class object is used as a singleton. -# Not exported in the trio._core namespace, but imported directly by _run. -class CancelShieldedCheckpoint: - pass +# Real yield value is from trio's main loop, but type checkers can't +# understand that, so we cast it to make type checkers understand. +_async_yield = cast( + "Callable[[MessageType], Awaitable[outcome.Outcome[object]]]", + _real_async_yield, +) async def cancel_shielded_checkpoint() -> None: @@ -66,18 +102,12 @@ class Abort(enum.Enum): FAILED = 2 -# Not exported in the trio._core namespace, but imported directly by _run. -@attrs.frozen(slots=False) -class WaitTaskRescheduled: - abort_func: Callable[[RaiseCancelT], Abort] - - -RaiseCancelT: TypeAlias = Callable[[], NoReturn] - - # Should always return the type a Task "expects", unless you willfully reschedule it # with a bad value. -async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any: +# Explicit "Any" is not allowed +async def wait_task_rescheduled( # type: ignore[misc] + abort_func: Callable[[RaiseCancelT], Abort], +) -> Any: """Put the current task to sleep, with cancellation support. This is the lowest-level API for blocking in Trio. Every time a @@ -179,15 +209,9 @@ def abort(inner_raise_cancel): return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap() -# Not exported in the trio._core namespace, but imported directly by _run. -@attrs.frozen(slots=False) -class PermanentlyDetachCoroutineObject: - final_outcome: outcome.Outcome[Any] - - async def permanently_detach_coroutine_object( - final_outcome: outcome.Outcome[Any], -) -> Any: + final_outcome: outcome.Outcome[object], +) -> object: """Permanently detach the current task from the Trio scheduler. Normally, a Trio task doesn't exit until its coroutine object exits. When @@ -213,14 +237,14 @@ async def permanently_detach_coroutine_object( """ if _run.current_task().child_nurseries: raise RuntimeError( - "can't permanently detach a coroutine object with open nurseries" + "can't permanently detach a coroutine object with open nurseries", ) return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome)) async def temporarily_detach_coroutine_object( - abort_func: Callable[[RaiseCancelT], Abort] -) -> Any: + abort_func: Callable[[RaiseCancelT], Abort], +) -> object: """Temporarily detach the current coroutine object from the Trio scheduler. diff --git a/src/trio/_core/_wakeup_socketpair.py b/src/trio/_core/_wakeup_socketpair.py index fb821a23e7..ea4567017f 100644 --- a/src/trio/_core/_wakeup_socketpair.py +++ b/src/trio/_core/_wakeup_socketpair.py @@ -63,7 +63,7 @@ def wakeup_on_signals(self) -> None: "running Trio in guest mode, then this might mean you " "should set host_uses_signal_set_wakeup_fd=True. " "Otherwise, file a bug on Trio and we'll help you figure " - "out what's going on." + "out what's going on.", ), stacklevel=1, ) diff --git a/src/trio/_core/_windows_cffi.py b/src/trio/_core/_windows_cffi.py index 244ea773c5..575fcb5601 100644 --- a/src/trio/_core/_windows_cffi.py +++ b/src/trio/_core/_windows_cffi.py @@ -210,7 +210,7 @@ # cribbed from pywincffi # programmatically strips out those annotations MSDN likes, like _In_ REGEX_SAL_ANNOTATION = re.compile( - r"\b(_In_|_Inout_|_Out_|_Outptr_|_Reserved_)(opt_)?\b" + r"\b(_In_|_Inout_|_Out_|_Outptr_|_Reserved_)(opt_)?\b", ) LIB = REGEX_SAL_ANNOTATION.sub(" ", LIB) @@ -253,7 +253,10 @@ def CreateEventA( ) -> Handle: ... def SetFileCompletionNotificationModes( - self, handle: Handle, flags: CompletionModes, / + self, + handle: Handle, + flags: CompletionModes, + /, ) -> int: ... def PostQueuedCompletionStatus( @@ -392,9 +395,9 @@ class _Overlapped(Protocol): hEvent: Handle -kernel32 = cast(_Kernel32, ffi.dlopen("kernel32.dll")) -ntdll = cast(_Nt, ffi.dlopen("ntdll.dll")) -ws2_32 = cast(_Ws2, ffi.dlopen("ws2_32.dll")) +kernel32 = cast("_Kernel32", ffi.dlopen("kernel32.dll")) +ntdll = cast("_Nt", ffi.dlopen("ntdll.dll")) +ws2_32 = cast("_Ws2", ffi.dlopen("ws2_32.dll")) ################################################################ # Magic numbers diff --git a/src/trio/_dtls.py b/src/trio/_dtls.py index 31f7817e1c..a7709632a4 100644 --- a/src/trio/_dtls.py +++ b/src/trio/_dtls.py @@ -19,12 +19,7 @@ from itertools import count from typing import ( TYPE_CHECKING, - Any, - Awaitable, - Callable, Generic, - Iterable, - Iterator, TypeVar, Union, ) @@ -37,12 +32,14 @@ from ._util import NoPublicConstructor, final if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Iterable, Iterator from types import TracebackType # See DTLSEndpoint.__init__ for why this is imported here - from OpenSSL import SSL # noqa: TCH004 + from OpenSSL import SSL # noqa: TC004 from typing_extensions import Self, TypeAlias, TypeVarTuple, Unpack + from trio._socket import AddressFormat from trio.socket import SocketType PosArgsT = TypeVarTuple("PosArgsT") @@ -61,7 +58,7 @@ def worst_case_mtu(sock: SocketType) -> int: if sock.family == trio.socket.AF_INET: return 576 - packet_header_overhead(sock) else: - return 1280 - packet_header_overhead(sock) + return 1280 - packet_header_overhead(sock) # TODO: test this line def best_guess_mtu(sock: SocketType) -> int: @@ -225,7 +222,7 @@ def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment: frag_offset_bytes, frag_len_bytes, ) = HANDSHAKE_MESSAGE_HEADER.unpack_from(payload) - except struct.error as exc: + except struct.error as exc: # TODO: test this line raise BadPacket("bad handshake message header") from exc # 'struct' doesn't have built-in support for 24-bit integers, so we # have to do it by hand. These can't fail. @@ -353,7 +350,9 @@ class OpaqueHandshakeMessage: _AnyHandshakeMessage: TypeAlias = Union[ - HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage + HandshakeMessage, + PseudoHandshakeMessage, + OpaqueHandshakeMessage, ] @@ -378,8 +377,10 @@ def decode_volley_trusted( elif record.content_type in (ContentType.change_cipher_spec, ContentType.alert): messages.append( PseudoHandshakeMessage( - record.version, record.content_type, record.payload - ) + record.version, + record.content_type, + record.payload, + ), ) else: assert record.content_type == ContentType.handshake @@ -424,14 +425,14 @@ def encode_volley( for message in messages: if isinstance(message, OpaqueHandshakeMessage): encoded = encode_record(message.record) - if mtu - len(packet) - len(encoded) <= 0: + if mtu - len(packet) - len(encoded) <= 0: # TODO: test this line packets.append(packet) packet = bytearray() packet += encoded assert len(packet) <= mtu elif isinstance(message, PseudoHandshakeMessage): space = mtu - len(packet) - RECORD_HEADER.size - len(message.payload) - if space <= 0: + if space <= 0: # TODO: test this line packets.append(packet) packet = bytearray() packet += RECORD_HEADER.pack( @@ -557,15 +558,18 @@ def _current_cookie_tick() -> int: # Simple deterministic and invertible serializer -- i.e., a useful tool for converting # structured data into something we can cryptographically sign. def _signable(*fields: bytes) -> bytes: - out = [] + out: list[bytes] = [] for field in fields: - out.append(struct.pack("!Q", len(field))) - out.append(field) + out.extend((struct.pack("!Q", len(field)), field)) return b"".join(out) def _make_cookie( - key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes + key: bytes, + salt: bytes, + tick: int, + address: AddressFormat, + client_hello_bits: bytes, ) -> bytes: assert len(salt) == SALT_BYTES assert len(key) == KEY_BYTES @@ -583,7 +587,10 @@ def _make_cookie( def valid_cookie( - key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes + key: bytes, + cookie: bytes, + address: AddressFormat, + client_hello_bits: bytes, ) -> bool: if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] @@ -592,20 +599,28 @@ def valid_cookie( cur_cookie = _make_cookie(key, salt, tick, address, client_hello_bits) old_cookie = _make_cookie( - key, salt, max(tick - 1, 0), address, client_hello_bits + key, + salt, + max(tick - 1, 0), + address, + client_hello_bits, ) # I doubt using a short-circuiting 'or' here would leak any meaningful # information, but why risk it when '|' is just as easy. return hmac.compare_digest(cookie, cur_cookie) | hmac.compare_digest( - cookie, old_cookie + cookie, + old_cookie, ) else: return False def challenge_for( - key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes + key: bytes, + address: AddressFormat, + epoch_seqno: int, + client_hello_bits: bytes, ) -> bytes: salt = os.urandom(SALT_BYTES) tick = _current_cookie_tick() @@ -641,7 +656,7 @@ def challenge_for( payload = encode_handshake_fragment(hs) packet = encode_record( - Record(ContentType.handshake, ProtocolVersion.DTLS10, epoch_seqno, payload) + Record(ContentType.handshake, ProtocolVersion.DTLS10, epoch_seqno, payload), ) return packet @@ -650,7 +665,7 @@ def challenge_for( class _Queue(Generic[_T]): - def __init__(self, incoming_packets_buffer: int | float): # noqa: PYI041 + def __init__(self, incoming_packets_buffer: int | float) -> None: # noqa: PYI041 self.s, self.r = trio.open_memory_channel[_T](incoming_packets_buffer) @@ -666,7 +681,9 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: async def handle_client_hello_untrusted( - endpoint: DTLSEndpoint, address: Any, packet: bytes + endpoint: DTLSEndpoint, + address: AddressFormat, + packet: bytes, ) -> None: # it's trivial to write a simple function that directly calls this to # get code coverage, but it should maybe: @@ -686,7 +703,10 @@ async def handle_client_hello_untrusted( if not valid_cookie(endpoint._listening_key, cookie, address, bits): challenge_packet = challenge_for( - endpoint._listening_key, address, epoch_seqno, bits + endpoint._listening_key, + address, + epoch_seqno, + bits, ) try: async with endpoint._send_lock: @@ -714,23 +734,6 @@ async def handle_client_hello_untrusted( # after all. return - # Some old versions of OpenSSL have a bug with memory BIOs, where DTLSv1_listen - # consumes the ClientHello out of the BIO, but then do_handshake expects the - # ClientHello to still be in there (but not the one that ships with Ubuntu - # 20.04). In particular, this is known to affect the OpenSSL v1.1.1 that ships - # with Ubuntu 18.04. To work around this, we deliver a second copy of the - # ClientHello after DTLSv1_listen has completed. This is safe to do - # unconditionally, because on newer versions of OpenSSL, the second ClientHello - # is treated as a duplicate packet, which is a normal thing that can happen over - # UDP. For more details, see: - # - # https://github.com/pyca/pyopenssl/blob/e84e7b57d1838de70ab7a27089fbee78ce0d2106/tests/test_ssl.py#L4226-L4293 - # - # This was fixed in v1.1.1a, and all later versions. So maybe in 2024 or so we - # can delete this. The fix landed in OpenSSL master as 079ef6bd534d2, and then - # was backported to the 1.1.1 branch as d1bfd8076e28. - stream._ssl.bio_write(packet) - # Check if we have an existing association old_stream = endpoint._streams.get(address) if old_stream is not None: @@ -746,7 +749,8 @@ async def handle_client_hello_untrusted( async def dtls_receive_loop( - endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType + endpoint_ref: ReferenceType[DTLSEndpoint], + sock: SocketType, ) -> None: try: while True: @@ -839,7 +843,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): def __init__( self, endpoint: DTLSEndpoint, - peer_address: Any, + peer_address: AddressFormat, ctx: SSL.Context, ) -> None: self.endpoint = endpoint @@ -853,7 +857,7 @@ def __init__( # support and isn't useful anyway -- especially for DTLS where it's equivalent # to just performing a new handshake. ctx.set_options( - SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION # type: ignore[attr-defined] + SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION, # type: ignore[attr-defined] ) self._ssl = SSL.Connection(ctx) self._handshake_mtu = 0 @@ -877,7 +881,7 @@ def _set_replaced(self) -> None: def _check_replaced(self) -> None: if self._replaced: raise trio.BrokenResourceError( - "peer tore down this connection to start a new one" + "peer tore down this connection to start a new one", ) # XX on systems where we can (maybe just Linux?) take advantage of the kernel's PMTU @@ -931,7 +935,8 @@ async def aclose(self) -> None: async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None: packets = self._record_encoder.encode_volley( - volley_messages, self._handshake_mtu + volley_messages, + self._handshake_mtu, ) for packet in packets: async with self.endpoint._send_lock: @@ -1034,7 +1039,7 @@ def read_volley() -> list[_AnyHandshakeMessage]: if ( isinstance(maybe_volley[0], PseudoHandshakeMessage) and maybe_volley[0].content_type == ContentType.alert - ): + ): # TODO: test this line # we're sending an alert (e.g. due to a corrupted # packet). We want to send it once, but don't save it to # retransmit -- keep the last volley as the current @@ -1066,7 +1071,8 @@ def read_volley() -> list[_AnyHandshakeMessage]: # PMTU estimate is wrong? Let's try dropping it to the minimum # and hope that helps. self._handshake_mtu = min( - self._handshake_mtu, worst_case_mtu(self.endpoint.socket) + self._handshake_mtu, + worst_case_mtu(self.endpoint.socket), ) async def send(self, data: bytes) -> None: @@ -1082,7 +1088,8 @@ async def send(self, data: bytes) -> None: self._ssl.write(data) async with self.endpoint._send_lock: await self.endpoint.socket.sendto( - _read_loop(self._ssl.bio_read), self.peer_address + _read_loop(self._ssl.bio_read), + self.peer_address, ) async def receive(self) -> bytes: @@ -1212,7 +1219,9 @@ def __init__( # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary() + self._streams: WeakValueDictionary[AddressFormat, DTLSChannel] = ( + WeakValueDictionary() + ) self._listening_context: SSL.Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) @@ -1226,7 +1235,9 @@ def _ensure_receive_loop(self) -> None: # after we send our first packet. if not self._receive_loop_spawned: trio.lowlevel.spawn_system_task( - dtls_receive_loop, weakref.ref(self), self.socket + dtls_receive_loop, + weakref.ref(self), + self.socket, ) self._receive_loop_spawned = True @@ -1315,10 +1326,9 @@ async def handler(dtls_channel): raise trio.BusyResourceError("another task is already listening") try: self.socket.getsockname() - except OSError: - # TODO: Write test that triggers this - raise RuntimeError( # pragma: no cover - "DTLS socket must be bound before it can serve" + except OSError: # TODO: test this line + raise RuntimeError( + "DTLS socket must be bound before it can serve", ) from None self._ensure_receive_loop() # We do cookie verification ourselves, so tell OpenSSL not to worry about it. diff --git a/src/trio/_file_io.py b/src/trio/_file_io.py index ef867243f0..5307fb9425 100644 --- a/src/trio/_file_io.py +++ b/src/trio/_file_io.py @@ -1,6 +1,7 @@ from __future__ import annotations import io +from collections.abc import Callable, Iterable from functools import partial from typing import ( IO, @@ -8,9 +9,7 @@ Any, AnyStr, BinaryIO, - Callable, Generic, - Iterable, TypeVar, Union, overload, @@ -32,6 +31,8 @@ ) from typing_extensions import Literal + from ._sync import CapacityLimiter + # This list is also in the docs, make sure to keep them in sync _FILE_SYNC_ATTRS: set[str] = { "closed", @@ -242,7 +243,10 @@ def __getattr__(self, name: str) -> object: meth = getattr(self._wrapped, name) @async_wraps(self.__class__, self._wrapped.__class__, name) - async def wrapper(*args, **kwargs): + async def wrapper( + *args: Callable[..., T], + **kwargs: object | str | bool | CapacityLimiter | None, + ) -> T: func = partial(meth, *args, **kwargs) return await trio.to_thread.run_sync(func) @@ -445,7 +449,7 @@ async def open_file( newline: str | None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[Any]: +) -> AsyncIOWrapper[object]: """Asynchronous version of :func:`open`. Returns: @@ -463,12 +467,20 @@ async def open_file( :func:`trio.Path.open` """ - _file = wrap_file( + file_ = wrap_file( await trio.to_thread.run_sync( - io.open, file, mode, buffering, encoding, errors, newline, closefd, opener - ) + io.open, + file, + mode, + buffering, + encoding, + errors, + newline, + closefd, + opener, + ), ) - return _file + return file_ def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]: @@ -495,7 +507,7 @@ def has(attr: str) -> bool: if not (has("close") and (has("read") or has("write"))): raise TypeError( f"{file} does not implement required duck-file methods: " - "close and (read or write)" + "close and (read or write)", ) return AsyncIOWrapper(file) diff --git a/src/trio/_highlevel_generic.py b/src/trio/_highlevel_generic.py index 88a86318a3..041a684c62 100644 --- a/src/trio/_highlevel_generic.py +++ b/src/trio/_highlevel_generic.py @@ -53,7 +53,7 @@ def _is_halfclosable(stream: SendStream) -> TypeGuard[HalfCloseableStream]: @final -@attrs.define(eq=False, hash=False, slots=False) +@attrs.define(eq=False, slots=False) class StapledStream( HalfCloseableStream, Generic[SendStreamT, ReceiveStreamT], diff --git a/src/trio/_highlevel_open_tcp_listeners.py b/src/trio/_highlevel_open_tcp_listeners.py index 80555be33e..023b2b240f 100644 --- a/src/trio/_highlevel_open_tcp_listeners.py +++ b/src/trio/_highlevel_open_tcp_listeners.py @@ -42,7 +42,7 @@ # backlog just causes it to be silently truncated to the configured maximum, # so this is unnecessary -- we can just pass in "infinity" and get the maximum # that way. (Verified on Windows, Linux, macOS using -# notes-to-self/measure-listen-backlog.py) +# https://github.com/python-trio/trio/wiki/notes-to-self#measure-listen-backlogpy def _compute_backlog(backlog: int | None) -> int: # Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are # missing overflow protection, so we apply our own overflow protection. @@ -112,7 +112,10 @@ async def open_tcp_listeners( computed_backlog = _compute_backlog(backlog) addresses = await tsocket.getaddrinfo( - host, port, type=tsocket.SOCK_STREAM, flags=tsocket.AI_PASSIVE + host, + port, + type=tsocket.SOCK_STREAM, + flags=tsocket.AI_PASSIVE, ) listeners = [] @@ -159,7 +162,8 @@ async def open_tcp_listeners( "socket that that address could use" ) raise OSError(errno.EAFNOSUPPORT, msg) from ExceptionGroup( - msg, unsupported_address_families + msg, + unsupported_address_families, ) return listeners @@ -240,5 +244,8 @@ async def serve_tcp( """ listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog) await trio.serve_listeners( - handler, listeners, handler_nursery=handler_nursery, task_status=task_status + handler, + listeners, + handler_nursery=handler_nursery, + task_status=task_status, ) diff --git a/src/trio/_highlevel_open_tcp_stream.py b/src/trio/_highlevel_open_tcp_stream.py index d5c83da7c0..d4ec98355f 100644 --- a/src/trio/_highlevel_open_tcp_stream.py +++ b/src/trio/_highlevel_open_tcp_stream.py @@ -11,6 +11,8 @@ from collections.abc import Generator from socket import AddressFamily, SocketKind + from trio._socket import AddressFormat + if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -132,16 +134,9 @@ def close_all() -> Generator[set[SocketType], None, None]: raise BaseExceptionGroup("", errs) -def reorder_for_rfc_6555_section_5_4( - targets: list[ - tuple[ - AddressFamily, - SocketKind, - int, - str, - Any, - ] - ] +# Explicit "Any" is not allowed +def reorder_for_rfc_6555_section_5_4( # type: ignore[misc] + targets: list[tuple[AddressFamily, SocketKind, int, str, Any]], ) -> None: # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address # families (e.g. IPv4 and IPv6), then you should make sure that your first @@ -302,7 +297,7 @@ async def open_tcp_stream( # face of crash or cancellation async def attempt_connect( socket_args: tuple[AddressFamily, SocketKind, int], - sockaddr: Any, + sockaddr: AddressFormat, attempt_failed: trio.Event, ) -> None: nonlocal winning_socket @@ -346,14 +341,16 @@ async def attempt_connect( # better job of it because it knows the remote IP/port. with suppress(OSError, AttributeError): sock.setsockopt( - trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT, 1 + trio.socket.IPPROTO_IP, + trio.socket.IP_BIND_ADDRESS_NO_PORT, + 1, ) try: await sock.bind((local_address, 0)) except OSError: raise OSError( f"local_address={local_address!r} is incompatible " - f"with remote address {sockaddr!r}" + f"with remote address {sockaddr!r}", ) from None await sock.connect(sockaddr) @@ -382,7 +379,9 @@ async def attempt_connect( # workaround to check types until typing of nursery.start_soon improved if TYPE_CHECKING: await attempt_connect( - (address_family, socket_type, proto), addr, attempt_failed + (address_family, socket_type, proto), + addr, + attempt_failed, ) nursery.start_soon( diff --git a/src/trio/_highlevel_serve_listeners.py b/src/trio/_highlevel_serve_listeners.py index ec5a0efb3c..9b17f8d538 100644 --- a/src/trio/_highlevel_serve_listeners.py +++ b/src/trio/_highlevel_serve_listeners.py @@ -3,7 +3,8 @@ import errno import logging import os -from typing import Any, Awaitable, Callable, NoReturn, TypeVar +from collections.abc import Awaitable, Callable +from typing import Any, NoReturn, TypeVar import trio @@ -24,7 +25,8 @@ StreamT = TypeVar("StreamT", bound=trio.abc.AsyncResource) -ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any]) +# Explicit "Any" is not allowed +ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any]) # type: ignore[misc] Handler = Callable[[StreamT], Awaitable[object]] @@ -66,7 +68,8 @@ async def _serve_one_listener( # https://github.com/python/typing/issues/548 -async def serve_listeners( +# Explicit "Any" is not allowed +async def serve_listeners( # type: ignore[misc] handler: Handler[StreamT], listeners: list[ListenerT], *, @@ -143,5 +146,5 @@ async def serve_listeners( task_status.started(listeners) raise AssertionError( - "_serve_one_listener should never complete" + "_serve_one_listener should never complete", ) # pragma: no cover diff --git a/src/trio/_highlevel_socket.py b/src/trio/_highlevel_socket.py index 901e22f345..c04e66e1bf 100644 --- a/src/trio/_highlevel_socket.py +++ b/src/trio/_highlevel_socket.py @@ -68,7 +68,7 @@ class SocketStream(HalfCloseableStream): """ - def __init__(self, socket: SocketType): + def __init__(self, socket: SocketType) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -76,7 +76,7 @@ def __init__(self, socket: SocketType): self.socket = socket self._send_conflict_detector = ConflictDetector( - "another task is currently sending data on this SocketStream" + "another task is currently sending data on this SocketStream", ) # Socket defaults: @@ -167,12 +167,12 @@ def setsockopt( if length is None: if value is None: raise TypeError( - "invalid value for argument 'value', must not be None when specifying length" + "invalid value for argument 'value', must not be None when specifying length", ) return self.socket.setsockopt(level, option, value) if value is not None: raise TypeError( - f"invalid value for argument 'value': {value!r}, must be None when specifying optlen" + f"invalid value for argument 'value': {value!r}, must be None when specifying optlen", ) return self.socket.setsockopt(level, option, value, length) @@ -364,7 +364,7 @@ class SocketListener(Listener[SocketStream]): """ - def __init__(self, socket: SocketType): + def __init__(self, socket: SocketType) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: diff --git a/src/trio/_highlevel_ssl_helpers.py b/src/trio/_highlevel_ssl_helpers.py index 03562c9edb..1239491a43 100644 --- a/src/trio/_highlevel_ssl_helpers.py +++ b/src/trio/_highlevel_ssl_helpers.py @@ -59,7 +59,9 @@ async def open_ssl_over_tcp_stream( """ tcp_stream = await trio.open_tcp_stream( - host, port, happy_eyeballs_delay=happy_eyeballs_delay + host, + port, + happy_eyeballs_delay=happy_eyeballs_delay, ) if ssl_context is None: ssl_context = ssl.create_default_context() @@ -68,7 +70,10 @@ async def open_ssl_over_tcp_stream( ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF return trio.SSLStream( - tcp_stream, ssl_context, server_hostname=host, https_compatible=https_compatible + tcp_stream, + ssl_context, + server_hostname=host, + https_compatible=https_compatible, ) @@ -168,5 +173,8 @@ async def serve_ssl_over_tcp( backlog=backlog, ) await trio.serve_listeners( - handler, listeners, handler_nursery=handler_nursery, task_status=task_status + handler, + listeners, + handler_nursery=handler_nursery, + task_status=task_status, ) diff --git a/src/trio/_path.py b/src/trio/_path.py index b9b5749c35..a58136b75b 100644 --- a/src/trio/_path.py +++ b/src/trio/_path.py @@ -30,8 +30,9 @@ T = TypeVar("T") -def _wraps_async( - wrapped: Callable[..., Any] +# Explicit .../"Any" is not allowed +def _wraps_async( # type: ignore[misc] + wrapped: Callable[..., object], ) -> Callable[[Callable[P, T]], Callable[P, Awaitable[T]]]: def decorator(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: @@ -222,8 +223,7 @@ def __repr__(self) -> str: group = _wrap_method(pathlib.Path.group) if sys.platform != "win32" or sys.version_info >= (3, 12): is_mount = _wrap_method(pathlib.Path.is_mount) - if sys.version_info >= (3, 9): - readlink = _wrap_method_path(pathlib.Path.readlink) + readlink = _wrap_method_path(pathlib.Path.readlink) rename = _wrap_method_path(pathlib.Path.rename) replace = _wrap_method_path(pathlib.Path.replace) resolve = _wrap_method_path(pathlib.Path.resolve) diff --git a/src/trio/_repl.py b/src/trio/_repl.py index 73f050140e..f9efcc0017 100644 --- a/src/trio/_repl.py +++ b/src/trio/_repl.py @@ -22,7 +22,7 @@ class TrioInteractiveConsole(InteractiveConsole): # we make the type more specific on our subclass locals: dict[str, object] - def __init__(self, repl_locals: dict[str, object] | None = None): + def __init__(self, repl_locals: dict[str, object] | None = None) -> None: super().__init__(locals=repl_locals) self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT diff --git a/src/trio/_signals.py b/src/trio/_signals.py index f4d912808f..729c48ad4e 100644 --- a/src/trio/_signals.py +++ b/src/trio/_signals.py @@ -7,7 +7,7 @@ import trio -from ._util import ConflictDetector, is_main_thread, signal_raise +from ._util import ConflictDetector, is_main_thread if TYPE_CHECKING: from collections.abc import AsyncIterator, Callable, Generator, Iterable @@ -72,13 +72,13 @@ def __init__(self) -> None: self._pending: OrderedDict[int, None] = OrderedDict() self._lot = trio.lowlevel.ParkingLot() self._conflict_detector = ConflictDetector( - "only one task can iterate on a signal receiver at a time" + "only one task can iterate on a signal receiver at a time", ) self._closed = False def _add(self, signum: int) -> None: if self._closed: - signal_raise(signum) + signal.raise_signal(signum) else: self._pending[signum] = None self._lot.unpark() @@ -95,7 +95,7 @@ def deliver_next() -> None: if self._pending: signum, _ = self._pending.popitem(last=False) try: - signal_raise(signum) + signal.raise_signal(signum) finally: deliver_next() @@ -170,7 +170,7 @@ def open_signal_receiver( if not is_main_thread(): raise RuntimeError( "Sorry, open_signal_receiver is only possible when running in " - "Python interpreter's main thread" + "Python interpreter's main thread", ) token = trio.lowlevel.current_trio_token() queue = SignalReceiver() diff --git a/src/trio/_socket.py b/src/trio/_socket.py index 0a3bd1cba1..4dde512985 100644 --- a/src/trio/_socket.py +++ b/src/trio/_socket.py @@ -9,9 +9,6 @@ from typing import ( TYPE_CHECKING, Any, - Awaitable, - Callable, - Literal, SupportsIndex, TypeVar, Union, @@ -26,7 +23,7 @@ from . import _core if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Awaitable, Callable, Iterable from types import TracebackType from typing_extensions import Buffer, Concatenate, ParamSpec, Self, TypeAlias @@ -49,7 +46,8 @@ # most users, so currently we just specify it as `Any`. Otherwise we would write: # `AddressFormat = TypeVar("AddressFormat")` # but instead we simply do: -AddressFormat: TypeAlias = Any +# Explicit "Any" is not allowed +AddressFormat: TypeAlias = Any # type: ignore[misc] # Usage: @@ -62,8 +60,9 @@ # class _try_sync: def __init__( - self, blocking_exc_override: Callable[[BaseException], bool] | None = None - ): + self, + blocking_exc_override: Callable[[BaseException], bool] | None = None, + ) -> None: self._blocking_exc_override = blocking_exc_override def _is_blocking_io_error(self, exc: BaseException) -> bool: @@ -179,7 +178,11 @@ async def getaddrinfo( flags: int = 0, ) -> list[ tuple[ - AddressFamily, SocketKind, int, str, tuple[str, int] | tuple[str, int, int, int] + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], ] ]: """Look up a numeric address given a name. @@ -210,7 +213,12 @@ def numeric_only_failure(exc: BaseException) -> bool: async with _try_sync(numeric_only_failure): return _stdlib_socket.getaddrinfo( - host, port, family, type, proto, flags | _NUMERIC_ONLY + host, + port, + family, + type, + proto, + flags | _NUMERIC_ONLY, ) # That failed; it's a real hostname. We better use a thread. # @@ -245,7 +253,8 @@ def numeric_only_failure(exc: BaseException) -> bool: async def getnameinfo( - sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, ) -> tuple[str, str]: """Look up a name given a numeric address. @@ -261,7 +270,10 @@ async def getnameinfo( return await hr.getnameinfo(sockaddr, flags) else: return await trio.to_thread.run_sync( - _stdlib_socket.getnameinfo, sockaddr, flags, abandon_on_cancel=True + _stdlib_socket.getnameinfo, + sockaddr, + flags, + abandon_on_cancel=True, ) @@ -272,7 +284,9 @@ async def getprotobyname(name: str) -> int: """ return await trio.to_thread.run_sync( - _stdlib_socket.getprotobyname, name, abandon_on_cancel=True + _stdlib_socket.getprotobyname, + name, + abandon_on_cancel=True, ) @@ -319,7 +333,7 @@ def fromshare(info: bytes) -> SocketType: TypeT: TypeAlias = int FamilyDefault = _stdlib_socket.AF_INET else: - FamilyDefault: Literal[None] = None + FamilyDefault: None = None FamilyT: TypeAlias = Union[int, AddressFamily, None] TypeT: TypeAlias = Union[_stdlib_socket.socket, int] @@ -357,7 +371,10 @@ def socket( return sf.socket(family, type, proto) else: family, type, proto = _sniff_sockopts_for_fileno( # noqa: A001 - family, type, proto, fileno + family, + type, + proto, + fileno, ) stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno) return from_stdlib_socket(stdlib_socket) @@ -457,7 +474,7 @@ async def _resolve_address_nocp( ipv6_v6only: bool | int, address: AddressFormat, local: bool, -) -> Any: +) -> AddressFormat: # Do some pre-checking (or exit early for non-IP sockets) if family == _stdlib_socket.AF_INET: if not isinstance(address, tuple) or not len(address) == 2: @@ -465,11 +482,11 @@ async def _resolve_address_nocp( elif family == _stdlib_socket.AF_INET6: if not isinstance(address, tuple) or not 2 <= len(address) <= 4: raise ValueError( - "address should be a (host, port, [flowinfo, [scopeid]]) tuple" + "address should be a (host, port, [flowinfo, [scopeid]]) tuple", ) elif hasattr(_stdlib_socket, "AF_UNIX") and family == _stdlib_socket.AF_UNIX: # unwrap path-likes - assert isinstance(address, (str, bytes)) + assert isinstance(address, (str, bytes, os.PathLike)) return os.fspath(address) else: return address @@ -531,7 +548,7 @@ def __init__(self) -> None: if type(self) is SocketType: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " - "want to construct a socket object" + "want to construct a socket object", ) def detach(self) -> int: @@ -547,27 +564,33 @@ def getsockname(self) -> AddressFormat: raise NotImplementedError @overload - def getsockopt(self, /, level: int, optname: int) -> int: ... + def getsockopt(self, level: int, optname: int) -> int: ... @overload - def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... + def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: ... def getsockopt( - self, /, level: int, optname: int, buflen: int | None = None + self, + level: int, + optname: int, + buflen: int | None = None, ) -> int | bytes: raise NotImplementedError @overload - def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ... + def setsockopt(self, level: int, optname: int, value: int | Buffer) -> None: ... @overload def setsockopt( - self, /, level: int, optname: int, value: None, optlen: int + self, + level: int, + optname: int, + value: None, + optlen: int, ) -> None: ... def setsockopt( self, - /, level: int, optname: int, value: int | Buffer | None, @@ -575,7 +598,7 @@ def setsockopt( ) -> None: raise NotImplementedError - def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + def listen(self, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: raise NotImplementedError def get_inheritable(self) -> bool: @@ -588,7 +611,7 @@ def set_inheritable(self, inheritable: bool) -> None: not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") ): - def share(self, /, process_id: int) -> bytes: + def share(self, process_id: int) -> bytes: raise NotImplementedError def __enter__(self) -> Self: @@ -648,24 +671,32 @@ async def accept(self) -> tuple[SocketType, AddressFormat]: async def connect(self, address: AddressFormat) -> None: raise NotImplementedError - # argument names with __ used because of typeshed, see comment for recv in _SocketType - def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: + def recv(self, buflen: int, flags: int = 0, /) -> Awaitable[bytes]: raise NotImplementedError def recv_into( - __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + self, + buffer: Buffer, + nbytes: int = 0, + flags: int = 0, ) -> Awaitable[int]: raise NotImplementedError # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] def recvfrom( - __self, __bufsize: int, __flags: int = 0 + self, + bufsize: int, + flags: int = 0, + /, ) -> Awaitable[tuple[bytes, AddressFormat]]: raise NotImplementedError # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] def recvfrom_into( - __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + self, + buffer: Buffer, + nbytes: int = 0, + flags: int = 0, ) -> Awaitable[tuple[int, AddressFormat]]: raise NotImplementedError @@ -674,11 +705,12 @@ def recvfrom_into( ): def recvmsg( - __self, - __bufsize: int, - __ancbufsize: int = 0, - __flags: int = 0, - ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + self, + bufsize: int, + ancbufsize: int = 0, + flags: int = 0, + /, + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, object]]: raise NotImplementedError if sys.platform != "win32" or ( @@ -686,30 +718,35 @@ def recvmsg( ): def recvmsg_into( - __self, - __buffers: Iterable[Buffer], - __ancbufsize: int = 0, - __flags: int = 0, - ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + self, + buffers: Iterable[Buffer], + ancbufsize: int = 0, + flags: int = 0, + /, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, object]]: raise NotImplementedError - def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: + def send(self, bytes: Buffer, flags: int = 0, /) -> Awaitable[int]: raise NotImplementedError @overload async def sendto( - self, __data: Buffer, __address: tuple[object, ...] | str | Buffer + self, + data: Buffer, + address: tuple[object, ...] | str | Buffer, + /, ) -> int: ... @overload async def sendto( self, - __data: Buffer, - __flags: int, - __address: tuple[object, ...] | str | Buffer, + data: Buffer, + flags: int, + address: tuple[object, ...] | str | Buffer, + /, ) -> int: ... - async def sendto(self, *args: Any) -> int: + async def sendto(self, *args: object) -> int: raise NotImplementedError if sys.platform != "win32" or ( @@ -719,10 +756,11 @@ async def sendto(self, *args: Any) -> int: @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) async def sendmsg( self, - __buffers: Iterable[Buffer], - __ancdata: Iterable[tuple[int, int, Buffer]] = (), - __flags: int = 0, - __address: AddressFormat | None = None, + buffers: Iterable[Buffer], + ancdata: Iterable[tuple[int, int, Buffer]] = (), + flags: int = 0, + address: AddressFormat | None = None, + /, ) -> int: raise NotImplementedError @@ -743,12 +781,12 @@ async def sendmsg( class _SocketType(SocketType): - def __init__(self, sock: _stdlib_socket.socket): + def __init__(self, sock: _stdlib_socket.socket) -> None: if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we # certainly don't want to blindly wrap one of those. raise TypeError( - f"expected object of type 'socket.socket', not '{type(sock).__name__}'" + f"expected object of type 'socket.socket', not '{type(sock).__name__}'", ) self._sock = sock self._sock.setblocking(False) @@ -772,29 +810,35 @@ def getsockname(self) -> AddressFormat: return self._sock.getsockname() @overload - def getsockopt(self, /, level: int, optname: int) -> int: ... + def getsockopt(self, level: int, optname: int) -> int: ... @overload - def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... + def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: ... def getsockopt( - self, /, level: int, optname: int, buflen: int | None = None + self, + level: int, + optname: int, + buflen: int | None = None, ) -> int | bytes: if buflen is None: return self._sock.getsockopt(level, optname) return self._sock.getsockopt(level, optname, buflen) @overload - def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ... + def setsockopt(self, level: int, optname: int, value: int | Buffer) -> None: ... @overload def setsockopt( - self, /, level: int, optname: int, value: None, optlen: int + self, + level: int, + optname: int, + value: None, + optlen: int, ) -> None: ... def setsockopt( self, - /, level: int, optname: int, value: int | Buffer | None, @@ -803,19 +847,19 @@ def setsockopt( if optlen is None: if value is None: raise TypeError( - "invalid value for argument 'value', must not be None when specifying optlen" + "invalid value for argument 'value', must not be None when specifying optlen", ) return self._sock.setsockopt(level, optname, value) if value is not None: raise TypeError( - f"invalid value for argument 'value': {value!r}, must be None when specifying optlen" + f"invalid value for argument 'value': {value!r}, must be None when specifying optlen", ) # Note: PyPy may crash here due to setsockopt only supporting # four parameters. return self._sock.setsockopt(level, optname, value, optlen) - def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + def listen(self, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: return self._sock.listen(backlog) def get_inheritable(self) -> bool: @@ -828,7 +872,7 @@ def set_inheritable(self, inheritable: bool) -> None: not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") ): - def share(self, /, process_id: int) -> bytes: + def share(self, process_id: int) -> bytes: return self._sock.share(process_id) def __enter__(self) -> Self: @@ -915,7 +959,8 @@ async def _resolve_address_nocp( ) -> AddressFormat: if self.family == _stdlib_socket.AF_INET6: ipv6_v6only = self._sock.getsockopt( - _stdlib_socket.IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY + _stdlib_socket.IPPROTO_IPV6, + _stdlib_socket.IPV6_V6ONLY, ) else: ipv6_v6only = False @@ -977,7 +1022,8 @@ async def _nonblocking_helper( ################################################################ _accept = _make_simple_sock_method_wrapper( - _stdlib_socket.socket.accept, _core.wait_readable + _stdlib_socket.socket.accept, + _core.wait_readable, ) async def accept(self) -> tuple[SocketType, AddressFormat]: @@ -1069,13 +1115,11 @@ async def connect(self, address: AddressFormat) -> None: # complain about AmbiguousType if TYPE_CHECKING: - def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: ... + def recv(self, buflen: int, flags: int = 0, /) -> Awaitable[bytes]: ... - # _make_simple_sock_method_wrapper is typed, so this checks that the above is correct - # this requires that we refrain from using `/` to specify pos-only - # args, or mypy thinks the signature differs from typeshed. recv = _make_simple_sock_method_wrapper( - _stdlib_socket.socket.recv, _core.wait_readable + _stdlib_socket.socket.recv, + _core.wait_readable, ) ################################################################ @@ -1085,11 +1129,16 @@ def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: ... if TYPE_CHECKING: def recv_into( - __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + self, + /, + buffer: Buffer, + nbytes: int = 0, + flags: int = 0, ) -> Awaitable[int]: ... recv_into = _make_simple_sock_method_wrapper( - _stdlib_socket.socket.recv_into, _core.wait_readable + _stdlib_socket.socket.recv_into, + _core.wait_readable, ) ################################################################ @@ -1099,11 +1148,15 @@ def recv_into( if TYPE_CHECKING: # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] def recvfrom( - __self, __bufsize: int, __flags: int = 0 + self, + bufsize: int, + flags: int = 0, + /, ) -> Awaitable[tuple[bytes, AddressFormat]]: ... recvfrom = _make_simple_sock_method_wrapper( - _stdlib_socket.socket.recvfrom, _core.wait_readable + _stdlib_socket.socket.recvfrom, + _core.wait_readable, ) ################################################################ @@ -1113,11 +1166,16 @@ def recvfrom( if TYPE_CHECKING: # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] def recvfrom_into( - __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + self, + /, + buffer: Buffer, + nbytes: int = 0, + flags: int = 0, ) -> Awaitable[tuple[int, AddressFormat]]: ... recvfrom_into = _make_simple_sock_method_wrapper( - _stdlib_socket.socket.recvfrom_into, _core.wait_readable + _stdlib_socket.socket.recvfrom_into, + _core.wait_readable, ) ################################################################ @@ -1130,11 +1188,17 @@ def recvfrom_into( if TYPE_CHECKING: def recvmsg( - __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0 - ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: ... + self, + bufsize: int, + ancbufsize: int = 0, + flags: int = 0, + /, + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, object]]: ... recvmsg = _make_simple_sock_method_wrapper( - _stdlib_socket.socket.recvmsg, _core.wait_readable, maybe_avail=True + _stdlib_socket.socket.recvmsg, + _core.wait_readable, + maybe_avail=True, ) ################################################################ @@ -1147,14 +1211,17 @@ def recvmsg( if TYPE_CHECKING: def recvmsg_into( - __self, - __buffers: Iterable[Buffer], - __ancbufsize: int = 0, - __flags: int = 0, - ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: ... + self, + buffers: Iterable[Buffer], + ancbufsize: int = 0, + flags: int = 0, + /, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, object]]: ... recvmsg_into = _make_simple_sock_method_wrapper( - _stdlib_socket.socket.recvmsg_into, _core.wait_readable, maybe_avail=True + _stdlib_socket.socket.recvmsg_into, + _core.wait_readable, + maybe_avail=True, ) ################################################################ @@ -1163,10 +1230,11 @@ def recvmsg_into( if TYPE_CHECKING: - def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: ... + def send(self, bytes: Buffer, flags: int = 0, /) -> Awaitable[int]: ... send = _make_simple_sock_method_wrapper( - _stdlib_socket.socket.send, _core.wait_writable + _stdlib_socket.socket.send, + _core.wait_writable, ) ################################################################ @@ -1175,16 +1243,23 @@ def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: ... @overload async def sendto( - self, __data: Buffer, __address: tuple[object, ...] | str | Buffer + self, + data: Buffer, + address: tuple[object, ...] | str | Buffer, + /, ) -> int: ... @overload async def sendto( - self, __data: Buffer, __flags: int, __address: tuple[object, ...] | str | Buffer + self, + data: Buffer, + flags: int, + address: tuple[object, ...] | str | Buffer, + /, ) -> int: ... - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] - async def sendto(self, *args: Any) -> int: + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) + async def sendto(self, *args: object) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address # and kwargs are not accepted @@ -1209,10 +1284,11 @@ async def sendto(self, *args: Any) -> int: @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) async def sendmsg( self, - __buffers: Iterable[Buffer], - __ancdata: Iterable[tuple[int, int, Buffer]] = (), - __flags: int = 0, - __address: AddressFormat | None = None, + buffers: Iterable[Buffer], + ancdata: Iterable[tuple[int, int, Buffer]] = (), + flags: int = 0, + address: AddressFormat | None = None, + /, ) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. @@ -1220,15 +1296,15 @@ async def sendmsg( available. """ - if __address is not None: - __address = await self._resolve_address_nocp(__address, local=False) + if address is not None: + address = await self._resolve_address_nocp(address, local=False) return await self._nonblocking_helper( _core.wait_writable, _stdlib_socket.socket.sendmsg, - __buffers, - __ancdata, - __flags, - __address, + buffers, + ancdata, + flags, + address, ) ################################################################ diff --git a/src/trio/_ssl.py b/src/trio/_ssl.py index 5bc37cf7dc..0a0419fbcb 100644 --- a/src/trio/_ssl.py +++ b/src/trio/_ssl.py @@ -16,6 +16,10 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from typing_extensions import TypeVarTuple, Unpack + + Ts = TypeVarTuple("Ts") + # General theory of operation: # # We implement an API that closely mirrors the stdlib ssl module's blocking @@ -219,7 +223,13 @@ class NeedHandshakeError(Exception): class _Once: - def __init__(self, afn: Callable[..., Awaitable[object]], *args: object) -> None: + __slots__ = ("_afn", "_args", "_done", "started") + + def __init__( + self, + afn: Callable[[*Ts], Awaitable[object]], + *args: Unpack[Ts], + ) -> None: self._afn = afn self._args = args self.started = False @@ -376,10 +386,10 @@ def __init__( # multiple concurrent calls to send_all/wait_send_all_might_not_block # or to receive_some. self._outer_send_conflict_detector = ConflictDetector( - "another task is currently sending data on this SSLStream" + "another task is currently sending data on this SSLStream", ) self._outer_recv_conflict_detector = ConflictDetector( - "another task is currently receiving data on this SSLStream" + "another task is currently receiving data on this SSLStream", ) self._estimated_receive_size = STARTING_RECEIVE_SIZE @@ -413,7 +423,11 @@ def __init__( "version", } - def __getattr__(self, name: str) -> Any: + # Explicit "Any" is not allowed + def __getattr__( # type: ignore[misc] + self, + name: str, + ) -> Any: if name in self._forwarded: if name in self._after_handshake and not self._handshook.done: raise NeedHandshakeError(f"call do_handshake() before calling {name!r}") @@ -447,8 +461,8 @@ def _check_status(self) -> None: # too. async def _retry( self, - fn: Callable[..., T], - *args: object, + fn: Callable[[*Ts], T], + *args: Unpack[Ts], ignore_want_read: bool = False, is_handshake: bool = False, ) -> T | None: @@ -615,7 +629,8 @@ async def _retry( self._incoming.write_eof() else: self._estimated_receive_size = max( - self._estimated_receive_size, len(data) + self._estimated_receive_size, + len(data), ) self._incoming.write(data) self._inner_recv_count += 1 @@ -893,6 +908,10 @@ async def wait_send_all_might_not_block(self) -> None: await self.transport_stream.wait_send_all_might_not_block() +# this is necessary for Sphinx, see also `_abc.py` +SSLStream.__module__ = SSLStream.__module__.replace("._ssl", "") + + @final class SSLListener(Listener[SSLStream[T_Stream]]): """A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers. diff --git a/src/trio/_subprocess.py b/src/trio/_subprocess.py index 553e3d4885..ff5cc8d393 100644 --- a/src/trio/_subprocess.py +++ b/src/trio/_subprocess.py @@ -32,12 +32,7 @@ # Sphinx cannot parse the stringified version -if sys.version_info >= (3, 9): - StrOrBytesPath: TypeAlias = Union[str, bytes, os.PathLike[str], os.PathLike[bytes]] -else: - StrOrBytesPath: TypeAlias = Union[ - str, bytes, "os.PathLike[str]", "os.PathLike[bytes]" - ] +StrOrBytesPath: TypeAlias = Union[str, bytes, os.PathLike[str], os.PathLike[bytes]] # Linux-specific, but has complex lifetime management stuff so we hard-code it @@ -238,7 +233,7 @@ async def wait(self) -> int: if self.poll() is None: if self._pidfd is not None: with contextlib.suppress( - ClosedResourceError + ClosedResourceError, ): # something else (probably a call to poll) already closed the pidfd await trio.lowlevel.wait_readable(self._pidfd.fileno()) else: @@ -301,7 +296,7 @@ def kill(self) -> None: async def _open_process( - command: list[str] | str, + command: StrOrBytesPath | Sequence[StrOrBytesPath], *, stdin: int | HasFileno | None = None, stdout: int | HasFileno | None = None, @@ -326,13 +321,14 @@ async def _open_process( want. Args: - command (list or str): The command to run. Typically this is a - sequence of strings such as ``['ls', '-l', 'directory with spaces']``, - where the first element names the executable to invoke and the other - elements specify its arguments. With ``shell=True`` in the - ``**options``, or on Windows, ``command`` may alternatively - be a string, which will be parsed following platform-dependent - :ref:`quoting rules `. + command: The command to run. Typically this is a sequence of strings or + bytes such as ``['ls', '-l', 'directory with spaces']``, where the + first element names the executable to invoke and the other elements + specify its arguments. With ``shell=True`` in the ``**options``, or on + Windows, ``command`` can be a string or bytes, which will be parsed + following platform-dependent :ref:`quoting rules + `. In all cases ``command`` can be a path or a + sequence of paths. stdin: Specifies what the child process's standard input stream should connect to: output written by the parent (``subprocess.PIPE``), nothing (``subprocess.DEVNULL``), @@ -362,19 +358,20 @@ async def _open_process( if options.get(key): raise TypeError( "trio.Process only supports communicating over " - f"unbuffered byte streams; the '{key}' option is not supported" + f"unbuffered byte streams; the '{key}' option is not supported", ) if os.name == "posix": - if isinstance(command, str) and not options.get("shell"): + # TODO: how do paths and sequences thereof play with `shell=True`? + if isinstance(command, (str, bytes)) and not options.get("shell"): raise TypeError( - "command must be a sequence (not a string) if shell=False " - "on UNIX systems" + "command must be a sequence (not a string or bytes) if " + "shell=False on UNIX systems", ) - if not isinstance(command, str) and options.get("shell"): + if not isinstance(command, (str, bytes)) and options.get("shell"): raise TypeError( - "command must be a string (not a sequence) if shell=True " - "on UNIX systems" + "command must be a string or bytes (not a sequence) if " + "shell=True on UNIX systems", ) trio_stdin: ClosableSendStream | None = None @@ -417,7 +414,7 @@ async def _open_process( stdout=stdout, stderr=stderr, **options, - ) + ), ) # We did not fail, so dismiss the stack for the trio ends cleanup_on_fail.pop_all() @@ -425,7 +422,8 @@ async def _open_process( return Process._create(popen, trio_stdin, trio_stdout, trio_stderr) -async def _windows_deliver_cancel(p: Process) -> None: +# async function missing await +async def _windows_deliver_cancel(p: Process) -> None: # noqa: RUF029 try: p.terminate() except OSError as exc: @@ -443,7 +441,7 @@ async def _posix_deliver_cancel(p: Process) -> None: RuntimeWarning( f"process {p!r} ignored SIGTERM for 5 seconds. " "(Maybe you should pass a custom deliver_cancel?) " - "Trying SIGKILL." + "Trying SIGKILL.", ), stacklevel=1, ) @@ -671,17 +669,17 @@ async def my_deliver_cancel(process): raise ValueError( "stdout=subprocess.PIPE is only valid with nursery.start, " "since that's the only way to access the pipe; use nursery.start " - "or pass the data you want to write directly" + "or pass the data you want to write directly", ) if options.get("stdout") is subprocess.PIPE: raise ValueError( "stdout=subprocess.PIPE is only valid with nursery.start, " - "since that's the only way to access the pipe" + "since that's the only way to access the pipe", ) if options.get("stderr") is subprocess.PIPE: raise ValueError( "stderr=subprocess.PIPE is only valid with nursery.start, " - "since that's the only way to access the pipe" + "since that's the only way to access the pipe", ) if isinstance(stdin, (bytes, bytearray, memoryview)): input_ = stdin @@ -768,7 +766,10 @@ async def killer() -> None: if proc.returncode and check: raise subprocess.CalledProcessError( - proc.returncode, proc.args, output=stdout, stderr=stderr + proc.returncode, + proc.args, + output=stdout, + stderr=stderr, ) else: assert proc.returncode is not None diff --git a/src/trio/_subprocess_platform/__init__.py b/src/trio/_subprocess_platform/__init__.py index d74cd462a0..daa28d8cd2 100644 --- a/src/trio/_subprocess_platform/__init__.py +++ b/src/trio/_subprocess_platform/__init__.py @@ -8,7 +8,7 @@ import trio from .. import _core, _subprocess -from .._abc import ReceiveStream, SendStream # noqa: TCH001 +from .._abc import ReceiveStream, SendStream # noqa: TC001 _wait_child_exiting_error: ImportError | None = None _create_child_pipe_error: ImportError | None = None @@ -85,11 +85,11 @@ def create_pipe_from_child_output() -> tuple[ClosableReceiveStream, int]: elif os.name == "posix": - def create_pipe_to_child_stdin(): + def create_pipe_to_child_stdin() -> tuple[trio.lowlevel.FdStream, int]: rfd, wfd = os.pipe() return trio.lowlevel.FdStream(wfd), rfd - def create_pipe_from_child_output(): + def create_pipe_from_child_output() -> tuple[trio.lowlevel.FdStream, int]: rfd, wfd = os.pipe() return trio.lowlevel.FdStream(rfd), wfd @@ -106,12 +106,12 @@ def create_pipe_from_child_output(): from .._windows_pipes import PipeReceiveStream, PipeSendStream - def create_pipe_to_child_stdin(): + def create_pipe_to_child_stdin() -> tuple[PipeSendStream, int]: # for stdin, we want the write end (our end) to use overlapped I/O rh, wh = windows_pipe(overlapped=(False, True)) return PipeSendStream(wh), msvcrt.open_osfhandle(rh, os.O_RDONLY) - def create_pipe_from_child_output(): + def create_pipe_from_child_output() -> tuple[PipeReceiveStream, int]: # for stdout/err, it's the read end that's overlapped rh, wh = windows_pipe(overlapped=(True, False)) return PipeReceiveStream(rh), msvcrt.open_osfhandle(wh, 0) diff --git a/src/trio/_subprocess_platform/kqueue.py b/src/trio/_subprocess_platform/kqueue.py index fcf72650ee..2283bb5360 100644 --- a/src/trio/_subprocess_platform/kqueue.py +++ b/src/trio/_subprocess_platform/kqueue.py @@ -21,7 +21,10 @@ async def wait_child_exiting(process: _subprocess.Process) -> None: def make_event(flags: int) -> select.kevent: return select.kevent( - process.pid, filter=select.KQ_FILTER_PROC, flags=flags, fflags=KQ_NOTE_EXIT + process.pid, + filter=select.KQ_FILTER_PROC, + flags=flags, + fflags=KQ_NOTE_EXIT, ) try: diff --git a/src/trio/_subprocess_platform/waitid.py b/src/trio/_subprocess_platform/waitid.py index 44c8261074..ebf83b4802 100644 --- a/src/trio/_subprocess_platform/waitid.py +++ b/src/trio/_subprocess_platform/waitid.py @@ -40,7 +40,7 @@ def sync_wait_reapable(pid: int) -> None: int pad[26]; } siginfo_t; int waitid(int idtype, int id, siginfo_t* result, int options); -""" +""", ) waitid_cffi = waitid_ffi.dlopen(None).waitid # type: ignore[attr-defined] @@ -79,7 +79,10 @@ async def _waitid_system_task(pid: int, event: Event) -> None: try: await to_thread_run_sync( - sync_wait_reapable, pid, abandon_on_cancel=True, limiter=waitid_limiter + sync_wait_reapable, + pid, + abandon_on_cancel=True, + limiter=waitid_limiter, ) except OSError: # If waitid fails, waitpid will fail too, so it still makes diff --git a/src/trio/_sync.py b/src/trio/_sync.py index 6e62eceeff..ca373922b0 100644 --- a/src/trio/_sync.py +++ b/src/trio/_sync.py @@ -8,7 +8,14 @@ import trio from . import _core -from ._core import Abort, ParkingLot, RaiseCancelT, enable_ki_protection +from ._core import ( + Abort, + ParkingLot, + RaiseCancelT, + add_parking_lot_breaker, + enable_ki_protection, + remove_parking_lot_breaker, +) from ._util import final if TYPE_CHECKING: @@ -33,7 +40,7 @@ class EventStatistics: @final -@attrs.define(repr=False, eq=False, hash=False) +@attrs.define(repr=False, eq=False) class Event: """A waitable boolean value useful for inter-task synchronization, inspired by :class:`threading.Event`. @@ -213,7 +220,7 @@ class CapacityLimiter(AsyncContextManagerMixin): """ # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing - def __init__(self, total_tokens: int | float): # noqa: PYI041 + def __init__(self, total_tokens: int | float) -> None: # noqa: PYI041 self._lot = ParkingLot() self._borrowers: set[Task | object] = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of @@ -297,7 +304,7 @@ def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None: """ if borrower in self._borrowers: raise RuntimeError( - "this borrower is already holding one of this CapacityLimiter's tokens" + "this borrower is already holding one of this CapacityLimiter's tokens", ) if len(self._borrowers) < self._total_tokens and not self._lot: self._borrowers.add(borrower) @@ -366,7 +373,7 @@ def release_on_behalf_of(self, borrower: Task | object) -> None: """ if borrower not in self._borrowers: raise RuntimeError( - "this borrower isn't holding any of this CapacityLimiter's tokens" + "this borrower isn't holding any of this CapacityLimiter's tokens", ) self._borrowers.remove(borrower) self._wake_waiters() @@ -426,7 +433,7 @@ class Semaphore(AsyncContextManagerMixin): """ - def __init__(self, initial_value: int, *, max_value: int | None = None): + def __init__(self, initial_value: int, *, max_value: int | None = None) -> None: if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") if initial_value < 0: @@ -538,7 +545,7 @@ class LockStatistics: tasks_waiting: int -@attrs.define(eq=False, hash=False, repr=False, slots=False) +@attrs.define(eq=False, repr=False, slots=False) class _LockImpl(AsyncContextManagerMixin): _lot: ParkingLot = attrs.field(factory=ParkingLot, init=False) _owner: Task | None = attrs.field(default=None, init=False) @@ -576,20 +583,30 @@ def acquire_nowait(self) -> None: elif self._owner is None and not self._lot: # No-one owns it self._owner = task + add_parking_lot_breaker(task, self._lot) else: raise trio.WouldBlock @enable_ki_protection async def acquire(self) -> None: - """Acquire the lock, blocking if necessary.""" + """Acquire the lock, blocking if necessary. + + Raises: + BrokenResourceError: if the owner of the lock exits without releasing. + """ await trio.lowlevel.checkpoint_if_cancelled() try: self.acquire_nowait() except trio.WouldBlock: - # NOTE: it's important that the contended acquire path is just - # "_lot.park()", because that's how Condition.wait() acquires the - # lock as well. - await self._lot.park() + try: + # NOTE: it's important that the contended acquire path is just + # "_lot.park()", because that's how Condition.wait() acquires the + # lock as well. + await self._lot.park() + except trio.BrokenResourceError: + raise trio.BrokenResourceError( + f"Owner of this lock exited without releasing: {self._owner}", + ) from None else: await trio.lowlevel.cancel_shielded_checkpoint() @@ -604,8 +621,10 @@ def release(self) -> None: task = trio.lowlevel.current_task() if task is not self._owner: raise RuntimeError("can't release a Lock you don't own") + remove_parking_lot_breaker(self._owner, self._lot) if self._lot: (self._owner,) = self._lot.unpark(count=1) + add_parking_lot_breaker(self._owner, self._lot) else: self._owner = None @@ -622,7 +641,9 @@ def statistics(self) -> LockStatistics: """ return LockStatistics( - locked=self.locked(), owner=self._owner, tasks_waiting=len(self._lot) + locked=self.locked(), + owner=self._owner, + tasks_waiting=len(self._lot), ) @@ -738,7 +759,7 @@ class Condition(AsyncContextManagerMixin): """ - def __init__(self, lock: Lock | None = None): + def __init__(self, lock: Lock | None = None) -> None: if lock is None: lock = Lock() if type(lock) is not Lock: @@ -765,7 +786,11 @@ def acquire_nowait(self) -> None: return self._lock.acquire_nowait() async def acquire(self) -> None: - """Acquire the underlying lock, blocking if necessary.""" + """Acquire the underlying lock, blocking if necessary. + + Raises: + BrokenResourceError: if the owner of the underlying lock exits without releasing. + """ await self._lock.acquire() def release(self) -> None: @@ -794,6 +819,7 @@ async def wait(self) -> None: Raises: RuntimeError: if the calling task does not hold the lock. + BrokenResourceError: if the owner of the lock exits without releasing, when attempting to re-acquire. """ if trio.lowlevel.current_task() is not self._lock._owner: @@ -845,5 +871,6 @@ def statistics(self) -> ConditionStatistics: """ return ConditionStatistics( - tasks_waiting=len(self._lot), lock_statistics=self._lock.statistics() + tasks_waiting=len(self._lot), + lock_statistics=self._lock.statistics(), ) diff --git a/src/trio/_tests/_check_type_completeness.json b/src/trio/_tests/_check_type_completeness.json index 0bbd47fada..72d981f89c 100644 --- a/src/trio/_tests/_check_type_completeness.json +++ b/src/trio/_tests/_check_type_completeness.json @@ -20,13 +20,10 @@ "all": [ "No docstring found for class \"trio.MemoryReceiveChannel\"", "No docstring found for class \"trio._channel.MemoryReceiveChannel\"", - "No docstring found for function \"trio._channel.MemoryReceiveChannel.statistics\"", - "No docstring found for class \"trio._channel.MemoryChannelStats\"", - "No docstring found for function \"trio._channel.MemoryReceiveChannel.aclose\"", + "No docstring found for class \"trio.MemoryChannelStatistics\"", + "No docstring found for class \"trio._channel.MemoryChannelStatistics\"", "No docstring found for class \"trio.MemorySendChannel\"", "No docstring found for class \"trio._channel.MemorySendChannel\"", - "No docstring found for function \"trio._channel.MemorySendChannel.statistics\"", - "No docstring found for function \"trio._channel.MemorySendChannel.aclose\"", "No docstring found for class \"trio._core._run.Task\"", "No docstring found for class \"trio._socket.SocketType\"", "No docstring found for function \"trio._highlevel_socket.SocketStream.send_all\"", @@ -43,7 +40,6 @@ "No docstring found for class \"trio._core._local.RunVarToken\"", "No docstring found for class \"trio.lowlevel.RunVarToken\"", "No docstring found for class \"trio.lowlevel.Task\"", - "No docstring found for class \"trio._core._ki.KIProtectionSignature\"", "No docstring found for class \"trio.socket.SocketType\"", "No docstring found for class \"trio.socket.gaierror\"", "No docstring found for class \"trio.socket.herror\"", diff --git a/src/trio/_tests/check_type_completeness.py b/src/trio/_tests/check_type_completeness.py index fa6ace074f..ea53aa0e37 100755 --- a/src/trio/_tests/check_type_completeness.py +++ b/src/trio/_tests/check_type_completeness.py @@ -32,7 +32,7 @@ def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]: "pyright", # Specify a platform and version to keep imported modules consistent. f"--pythonplatform={platform}", - "--pythonversion=3.8", + "--pythonversion=3.9", "--verifytypes=trio", "--outputjson", "--ignoreexternal", @@ -111,7 +111,9 @@ def has_docstring_at_runtime(name: str) -> bool: def check_type( - platform: str, full_diagnostics_file: Path | None, expected_errors: list[object] + platform: str, + full_diagnostics_file: Path | None, + expected_errors: list[object], ) -> list[object]: # convince isort we use the trio import assert trio @@ -144,7 +146,7 @@ def check_type( if message.startswith("No docstring found for"): continue if message.startswith( - "Type is missing type annotation and could be inferred differently by type checkers" + "Type is missing type annotation and could be inferred differently by type checkers", ): continue diff --git a/src/trio/_tests/module_with_deprecations.py b/src/trio/_tests/module_with_deprecations.py index 73184d11e8..afe4187191 100644 --- a/src/trio/_tests/module_with_deprecations.py +++ b/src/trio/_tests/module_with_deprecations.py @@ -16,6 +16,9 @@ __deprecated_attributes__ = { "dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1), "dep2": _deprecate.DeprecatedAttribute( - "value2", "1.2", issue=1, instead="instead-string" + "value2", + "1.2", + issue=1, + instead="instead-string", ), } diff --git a/src/trio/_tests/test_channel.py b/src/trio/_tests/test_channel.py index 1271f6b765..001f07568c 100644 --- a/src/trio/_tests/test_channel.py +++ b/src/trio/_tests/test_channel.py @@ -107,7 +107,8 @@ async def consumer(receive_channel: trio.MemoryReceiveChannel[int], i: int) -> N async def test_close_basics() -> None: async def send_block( - s: trio.MemorySendChannel[None], expect: type[BaseException] + s: trio.MemorySendChannel[None], + expect: type[BaseException], ) -> None: with pytest.raises(expect): await s.send(None) @@ -149,7 +150,7 @@ async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None: with pytest.raises(trio.ClosedResourceError): await r.receive() - s2, r2 = open_memory_channel[int](0) + _s2, r2 = open_memory_channel[int](0) async with trio.open_nursery() as nursery: nursery.start_soon(receive_block, r2) await wait_all_tasks_blocked() @@ -164,7 +165,8 @@ async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None: async def test_close_sync() -> None: async def send_block( - s: trio.MemorySendChannel[None], expect: type[BaseException] + s: trio.MemorySendChannel[None], + expect: type[BaseException], ) -> None: with pytest.raises(expect): await s.send(None) diff --git a/src/trio/_tests/test_contextvars.py b/src/trio/_tests/test_contextvars.py index ae0c25f876..63965e1e12 100644 --- a/src/trio/_tests/test_contextvars.py +++ b/src/trio/_tests/test_contextvars.py @@ -5,7 +5,7 @@ from .. import _core trio_testing_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar( - "trio_testing_contextvar" + "trio_testing_contextvar", ) diff --git a/src/trio/_tests/test_deprecate.py b/src/trio/_tests/test_deprecate.py index fa5d7cbfef..1da1549d38 100644 --- a/src/trio/_tests/test_deprecate.py +++ b/src/trio/_tests/test_deprecate.py @@ -161,7 +161,10 @@ def new_hotness_method(self) -> str: return "new hotness method" old_hotness_method = deprecated_alias( - "Alias.old_hotness_method", new_hotness_method, "3.21", issue=1 + "Alias.old_hotness_method", + new_hotness_method, + "3.21", + issue=1, ) @@ -272,5 +275,9 @@ def test_warning_class() -> None: with pytest.warns(TrioDeprecationWarning): warn_deprecated( - "foo", "bar", issue=None, instead=None, use_triodeprecationwarning=True + "foo", + "bar", + issue=None, + instead=None, + use_triodeprecationwarning=True, ) diff --git a/src/trio/_tests/test_deprecate_strict_exception_groups_false.py b/src/trio/_tests/test_deprecate_strict_exception_groups_false.py index 317672bf23..1b02c9ee73 100644 --- a/src/trio/_tests/test_deprecate_strict_exception_groups_false.py +++ b/src/trio/_tests/test_deprecate_strict_exception_groups_false.py @@ -1,4 +1,4 @@ -from typing import Awaitable, Callable +from collections.abc import Awaitable, Callable import pytest @@ -7,7 +7,8 @@ async def test_deprecation_warning_open_nursery() -> None: with pytest.warns( - trio.TrioDeprecationWarning, match="strict_exception_groups=False" + trio.TrioDeprecationWarning, + match="strict_exception_groups=False", ) as record: async with trio.open_nursery(strict_exception_groups=False): ... @@ -31,9 +32,10 @@ async def foo_loose_nursery() -> None: async with trio.open_nursery(strict_exception_groups=False): ... - def helper(fun: Callable[..., Awaitable[None]], num: int) -> None: + def helper(fun: Callable[[], Awaitable[None]], num: int) -> None: with pytest.warns( - trio.TrioDeprecationWarning, match="strict_exception_groups=False" + trio.TrioDeprecationWarning, + match="strict_exception_groups=False", ) as record: trio.run(fun, strict_exception_groups=False) assert len(record) == num @@ -52,7 +54,8 @@ async def trio_return(in_host: object) -> str: return "ok" with pytest.warns( - trio.TrioDeprecationWarning, match="strict_exception_groups=False" + trio.TrioDeprecationWarning, + match="strict_exception_groups=False", ) as record: trivial_guest_run( trio_return, diff --git a/src/trio/_tests/test_dtls.py b/src/trio/_tests/test_dtls.py index d14edae25c..141e891586 100644 --- a/src/trio/_tests/test_dtls.py +++ b/src/trio/_tests/test_dtls.py @@ -38,7 +38,9 @@ parametrize_ipv6 = pytest.mark.parametrize( - "ipv6", [False, pytest.param(True, marks=binds_ipv6)], ids=["ipv4", "ipv6"] + "ipv6", + [False, pytest.param(True, marks=binds_ipv6)], + ids=["ipv4", "ipv6"], ) @@ -51,7 +53,10 @@ def endpoint(**kwargs: int | bool) -> DTLSEndpoint: @asynccontextmanager async def dtls_echo_server( - *, autocancel: bool = True, mtu: int | None = None, ipv6: bool = False + *, + autocancel: bool = True, + mtu: int | None = None, + ipv6: bool = False, ) -> AsyncGenerator[tuple[DTLSEndpoint, tuple[str, int]], None]: with endpoint(ipv6=ipv6) as server: localhost = "::1" if ipv6 else "127.0.0.1" @@ -62,7 +67,7 @@ async def echo_handler(dtls_channel: DTLSChannel) -> None: print( "echo handler started: " f"server {dtls_channel.endpoint.socket.getsockname()!r} " - f"client {dtls_channel.peer_address!r}" + f"client {dtls_channel.peer_address!r}", ) if mtu is not None: dtls_channel.set_ciphertext_mtu(mtu) @@ -70,7 +75,9 @@ async def echo_handler(dtls_channel: DTLSChannel) -> None: print("server starting do_handshake") await dtls_channel.do_handshake() print("server finished do_handshake") - async for packet in dtls_channel: + # no branch for leaving this for loop because we only leave + # a channel by cancellation. + async for packet in dtls_channel: # pragma: no branch print(f"echoing {packet!r} -> {dtls_channel.peer_address!r}") await dtls_channel.send(packet) except trio.BrokenResourceError: # pragma: no cover @@ -86,7 +93,7 @@ async def echo_handler(dtls_channel: DTLSChannel) -> None: @parametrize_ipv6 async def test_smoke(ipv6: bool) -> None: - async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address): + async with dtls_echo_server(ipv6=ipv6) as (_server_endpoint, address): with endpoint(ipv6=ipv6) as client_endpoint: client_channel = client_endpoint.connect(address, client_ctx) with pytest.raises(trio.NeedHandshakeError): @@ -99,7 +106,8 @@ async def test_smoke(ipv6: bool) -> None: assert await client_channel.receive() == b"goodbye" with pytest.raises( - ValueError, match="^openssl doesn't support sending empty DTLS packets$" + ValueError, + match="^openssl doesn't support sending empty DTLS packets$", ): await client_channel.send(b"") @@ -165,7 +173,7 @@ async def route_packet(packet: UDPPacket) -> None: assert op == "deliver" print( f"{packet.source} -> {packet.destination}: delivered" - f" {packet.payload.hex()}" + f" {packet.payload.hex()}", ) fn.deliver_packet(packet) break @@ -225,7 +233,8 @@ async def handler(channel: DTLSChannel) -> None: await server_nursery.start(server_endpoint.serve, server_ctx, handler) client = client_endpoint.connect( - server_endpoint.socket.getsockname(), client_ctx + server_endpoint.socket.getsockname(), + client_ctx, ) async with trio.open_nursery() as nursery: nursery.start_soon(client.send, b"from client") @@ -253,7 +262,7 @@ async def test_channel_closing() -> None: async def test_serve_exits_cleanly_on_close() -> None: - async with dtls_echo_server(autocancel=False) as (server_endpoint, address): + async with dtls_echo_server(autocancel=False) as (server_endpoint, _address): server_endpoint.close() # Testing that the nursery exits even without being cancelled # close is idempotent @@ -372,9 +381,9 @@ async def test_server_socket_doesnt_crash_on_garbage( frag_offset=0, frag_len=10, frag=bytes(10), - ) + ), ), - ) + ), ) client_hello_extended = client_hello + b"\x00" @@ -397,9 +406,9 @@ async def test_server_socket_doesnt_crash_on_garbage( frag_offset=0, frag_len=10, frag=bytes(10), - ) + ), ), - ) + ), ) client_hello_trailing_data_in_record = encode_record( @@ -415,10 +424,10 @@ async def test_server_socket_doesnt_crash_on_garbage( frag_offset=0, frag_len=10, frag=bytes(10), - ) + ), ) + b"\x00", - ) + ), ) handshake_empty = encode_record( @@ -427,7 +436,7 @@ async def test_server_socket_doesnt_crash_on_garbage( version=ProtocolVersion.DTLS10, epoch_seqno=0, payload=b"", - ) + ), ) client_hello_truncated_in_cookie = encode_record( @@ -436,7 +445,7 @@ async def test_server_socket_doesnt_crash_on_garbage( version=ProtocolVersion.DTLS10, epoch_seqno=0, payload=bytes(2 + 32 + 1) + b"\xff", - ) + ), ) async with dtls_echo_server() as (_, address): @@ -620,7 +629,8 @@ async def connecter() -> None: # notices the timeout has expired blackholed = False await server_endpoint.socket.sendto( - b"xxx", client_endpoint.socket.getsockname() + b"xxx", + client_endpoint.socket.getsockname(), ) # now the client task should finish connecting and exit cleanly @@ -671,7 +681,7 @@ def route_packet(packet: UDPPacket) -> None: fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet - async with dtls_echo_server(mtu=MTU) as (server, address): + async with dtls_echo_server(mtu=MTU) as (_server, address): with endpoint() as client: channel = client.connect(address, client_ctx) channel.set_ciphertext_mtu(MTU) @@ -682,7 +692,8 @@ def route_packet(packet: UDPPacket) -> None: @parametrize_ipv6 async def test_handshake_handles_minimum_network_mtu( - ipv6: bool, autojump_clock: trio.abc.Clock + ipv6: bool, + autojump_clock: trio.abc.Clock, ) -> None: # Fake network that has the minimum allowable MTU for whatever protocol we're using. fn = FakeNet() diff --git a/src/trio/_tests/test_exports.py b/src/trio/_tests/test_exports.py index 32a2666e48..f89d4105e6 100644 --- a/src/trio/_tests/test_exports.py +++ b/src/trio/_tests/test_exports.py @@ -19,11 +19,10 @@ import trio import trio.testing -from trio._tests.pytest_plugin import skip_if_optional_else_raise +from trio._tests.pytest_plugin import RUN_SLOW, skip_if_optional_else_raise from .. import _core, _util from .._core._tests.tutil import slow -from .pytest_plugin import RUN_SLOW if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -55,7 +54,7 @@ def _ensure_mypy_cache_updated() -> None: "--no-error-summary", "-c", "import trio", - ] + ], ) assert not result[1] # stderr assert not result[0] # stdout @@ -72,7 +71,8 @@ def test_core_is_properly_reexported() -> None: found = 0 for source in sources: if symbol in dir(source) and getattr(source, symbol) is getattr( - _core, symbol + _core, + symbol, ): found += 1 print(symbol, found) @@ -116,7 +116,7 @@ def iter_modules( # they might be using a newer version of Python with additional symbols which # won't be reflected in trio.socket, and this shouldn't cause downstream test # runs to start failing. -@pytest.mark.redistributors_should_skip() +@pytest.mark.redistributors_should_skip # Static analysis tools often have trouble with alpha releases, where Python's # internals are in flux, grammar may not have settled down, etc. @pytest.mark.skipif( @@ -172,8 +172,6 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: elif tool == "mypy": if not RUN_SLOW: # pragma: no cover pytest.skip("use --run-slow to check against mypy") - if sys.implementation.name != "cpython": - pytest.skip("mypy not installed in tests on pypy") cache = Path.cwd() / ".mypy_cache" @@ -244,7 +242,7 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: # modules, instead of once per class. @slow # see comment on test_static_tool_sees_all_symbols -@pytest.mark.redistributors_should_skip() +@pytest.mark.redistributors_should_skip # Static analysis tools often have trouble with alpha releases, where Python's # internals are in flux, grammar may not have settled down, etc. @pytest.mark.skipif( @@ -254,7 +252,9 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: @pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES) @pytest.mark.parametrize("tool", ["jedi", "mypy"]) def test_static_tool_sees_class_members( - tool: str, module_name: str, tmp_path: Path + tool: str, + module_name: str, + tmp_path: Path, ) -> None: module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)] @@ -266,10 +266,10 @@ def no_hidden(symbols: Iterable[str]) -> set[str]: if (not symbol.startswith("_")) or symbol.startswith("__") } - if tool == "mypy": - if sys.implementation.name != "cpython": - pytest.skip("mypy not installed in tests on pypy") + if tool == "jedi" and sys.implementation.name != "cpython": + pytest.skip("jedi does not support pypy") + if tool == "mypy": cache = Path.cwd() / ".mypy_cache" _ensure_mypy_cache_updated() @@ -376,7 +376,7 @@ def lookup_symbol(symbol: str) -> dict[str, str]: skip_if_optional_else_raise(error) script = jedi.Script( - f"from {module_name} import {class_name}; {class_name}." + f"from {module_name} import {class_name}; {class_name}.", ) completions = script.complete() static_names = no_hidden(c.name for c in completions) - ignore_names @@ -384,16 +384,20 @@ def lookup_symbol(symbol: str) -> dict[str, str]: elif tool == "mypy": # load the cached type information cached_type_info = cache_json["names"][class_name] - if "node" not in cached_type_info: - cached_type_info = lookup_symbol(cached_type_info["cross_ref"]) + assert ( + "node" not in cached_type_info + ), "previously this was an 'if' but it seems it's no longer possible for this cache to contain 'node', if this assert raises for you please let us know!" + cached_type_info = lookup_symbol(cached_type_info["cross_ref"]) assert "node" in cached_type_info node = cached_type_info["node"] - static_names = no_hidden(k for k in node["names"] if not k.startswith(".")) + static_names = no_hidden( + k for k in node.get("names", ()) if not k.startswith(".") + ) for symbol in node["mro"][1:]: node = lookup_symbol(symbol)["node"] static_names |= no_hidden( - k for k in node["names"] if not k.startswith(".") + k for k in node.get("names", ()) if not k.startswith(".") ) static_names -= ignore_names @@ -571,3 +575,37 @@ def test_classes_are_final() -> None: continue assert class_is_final(class_) + + +# Plugin might not be running, especially if running from an installed version. +@pytest.mark.skipif( + not hasattr(attrs.field, "trio_modded"), + reason="Pytest plugin not installed.", +) +def test_pyright_recognizes_init_attributes() -> None: + """Check whether we provide `alias` for all underscore prefixed attributes. + + Attrs always sets the `alias` attribute on fields, so a pytest plugin is used + to monkeypatch `field()` to record whether an alias was defined in the metadata. + See `_trio_check_attrs_aliases`. + """ + for module in PUBLIC_MODULES: + for class_ in module.__dict__.values(): + if not attrs.has(class_): + continue + if isinstance(class_, _util.NoPublicConstructor): + continue + + attributes = [ + attr + for attr in attrs.fields(class_) + if attr.init + if attr.alias + not in ( + attr.name, + # trio_original_args may not be present in autoattribs + attr.metadata.get("trio_original_args", {}).get("alias"), + ) + ] + + assert attributes == [], class_ diff --git a/src/trio/_tests/test_fakenet.py b/src/trio/_tests/test_fakenet.py index bde6db0191..7a3c328e81 100644 --- a/src/trio/_tests/test_fakenet.py +++ b/src/trio/_tests/test_fakenet.py @@ -34,14 +34,16 @@ async def test_basic_udp() -> None: assert port != 0 with pytest.raises( - OSError, match=r"^\[\w+ \d+\] Invalid argument$" + OSError, + match=r"^\[\w+ \d+\] Invalid argument$", ) as exc: # Cannot rebind. await s1.bind(("192.0.2.1", 0)) assert exc.value.errno == errno.EINVAL # Cannot bind multiple sockets to the same address with pytest.raises( - OSError, match=r"^\[\w+ \d+\] (Address (already )?in use|Unknown error)$" + OSError, + match=r"^\[\w+ \d+\] (Address (already )?in use|Unknown error)$", ) as exc: await s2.bind(("127.0.0.1", port)) assert exc.value.errno == errno.EADDRINUSE @@ -62,7 +64,7 @@ async def test_msg_trunc() -> None: s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) await s1.bind(("127.0.0.1", 0)) await s2.sendto(b"xyz", s1.getsockname()) - data, addr = await s1.recvfrom(10) + await s1.recvfrom(10) async def test_recv_methods() -> None: @@ -131,7 +133,8 @@ async def test_recv_methods() -> None: @pytest.mark.skipif( - sys.platform == "win32", reason="functions not in socket on windows" + sys.platform == "win32", + reason="functions not in socket on windows", ) async def test_nonwindows_functionality() -> None: # mypy doesn't support a good way of aborting typechecking on different platforms @@ -180,13 +183,15 @@ async def test_nonwindows_functionality() -> None: assert addr == s1.getsockname() with pytest.raises( - AttributeError, match="^'FakeSocket' object has no attribute 'share'$" + AttributeError, + match="^'FakeSocket' object has no attribute 'share'$", ): await s1.share(0) # type: ignore[attr-defined] @pytest.mark.skipif( - sys.platform != "win32", reason="windows-specific fakesocket testing" + sys.platform != "win32", + reason="windows-specific fakesocket testing", ) async def test_windows_functionality() -> None: # mypy doesn't support a good way of aborting typechecking on different platforms @@ -196,11 +201,13 @@ async def test_windows_functionality() -> None: s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) await s1.bind(("127.0.0.1", 0)) with pytest.raises( - AttributeError, match="^'FakeSocket' object has no attribute 'sendmsg'$" + AttributeError, + match="^'FakeSocket' object has no attribute 'sendmsg'$", ): await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) # type: ignore[attr-defined] with pytest.raises( - AttributeError, match="^'FakeSocket' object has no attribute 'recvmsg'$" + AttributeError, + match="^'FakeSocket' object has no attribute 'recvmsg'$", ): s2.recvmsg(0) # type: ignore[attr-defined] with pytest.raises( @@ -224,28 +231,33 @@ async def test_not_implemented_functions() -> None: # getsockopt with pytest.raises( - OSError, match=r"^FakeNet doesn't implement getsockopt\(\d, \d\)$" + OSError, + match=r"^FakeNet doesn't implement getsockopt\(\d, \d\)$", ): s1.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) # setsockopt with pytest.raises( - NotImplementedError, match="^FakeNet always has IPV6_V6ONLY=True$" + NotImplementedError, + match="^FakeNet always has IPV6_V6ONLY=True$", ): s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False) with pytest.raises( - OSError, match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$" + OSError, + match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$", ): s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True) with pytest.raises( - OSError, match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$" + OSError, + match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$", ): s1.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # set_inheritable s1.set_inheritable(False) with pytest.raises( - NotImplementedError, match="^FakeNet can't make inheritable sockets$" + NotImplementedError, + match="^FakeNet can't make inheritable sockets$", ): s1.set_inheritable(True) @@ -274,7 +286,7 @@ async def test_init() -> None: with pytest.raises( NotImplementedError, match=re.escape( - f"FakeNet doesn't (yet) support type={trio.socket.SOCK_STREAM}" + f"FakeNet doesn't (yet) support type={trio.socket.SOCK_STREAM}", ), ): s1 = trio.socket.socket() diff --git a/src/trio/_tests/test_file_io.py b/src/trio/_tests/test_file_io.py index cd02d4f768..390a81ce61 100644 --- a/src/trio/_tests/test_file_io.py +++ b/src/trio/_tests/test_file_io.py @@ -59,13 +59,15 @@ def write(self) -> None: # pragma: no cover def test_wrapped_property( - async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock + async_file: AsyncIOWrapper[mock.Mock], + wrapped: mock.Mock, ) -> None: assert async_file.wrapped is wrapped def test_dir_matches_wrapped( - async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock + async_file: AsyncIOWrapper[mock.Mock], + wrapped: mock.Mock, ) -> None: attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS) @@ -132,7 +134,8 @@ def test_type_stubs_match_lists() -> None: def test_sync_attrs_forwarded( - async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock + async_file: AsyncIOWrapper[mock.Mock], + wrapped: mock.Mock, ) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name not in dir(async_file): @@ -142,7 +145,8 @@ def test_sync_attrs_forwarded( def test_sync_attrs_match_wrapper( - async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock + async_file: AsyncIOWrapper[mock.Mock], + wrapped: mock.Mock, ) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name in dir(async_file): @@ -174,7 +178,8 @@ def test_async_methods_signature(async_file: AsyncIOWrapper[mock.Mock]) -> None: async def test_async_methods_wrap( - async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock + async_file: AsyncIOWrapper[mock.Mock], + wrapped: mock.Mock, ) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name not in dir(async_file): @@ -186,15 +191,17 @@ async def test_async_methods_wrap( value = await meth(sentinel.argument, keyword=sentinel.keyword) wrapped_meth.assert_called_once_with( - sentinel.argument, keyword=sentinel.keyword + sentinel.argument, + keyword=sentinel.keyword, ) assert value == wrapped_meth() wrapped.reset_mock() -async def test_async_methods_match_wrapper( - async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock +def test_async_methods_match_wrapper( + async_file: AsyncIOWrapper[mock.Mock], + wrapped: mock.Mock, ) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name in dir(async_file): diff --git a/src/trio/_tests/test_highlevel_open_tcp_listeners.py b/src/trio/_tests/test_highlevel_open_tcp_listeners.py index 3196c9e533..e78e4414d2 100644 --- a/src/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/src/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -4,7 +4,7 @@ import socket as stdlib_socket import sys from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING, Any, Sequence, overload +from typing import TYPE_CHECKING, cast, overload import attrs import pytest @@ -26,8 +26,12 @@ from exceptiongroup import BaseExceptionGroup if TYPE_CHECKING: + from collections.abc import Sequence + from typing_extensions import Buffer + from trio._socket import AddressFormat + async def test_open_tcp_listeners_basic() -> None: listeners = await open_tcp_listeners(0) @@ -160,7 +164,11 @@ def getsockopt(self, /, level: int, optname: int) -> int: ... def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... def getsockopt( - self, /, level: int, optname: int, buflen: int | None = None + self, + /, + level: int, + optname: int, + buflen: int | None = None, ) -> int | bytes: if (level, optname) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN): return True @@ -171,7 +179,12 @@ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: @overload def setsockopt( - self, /, level: int, optname: int, value: None, optlen: int + self, + /, + level: int, + optname: int, + value: None, + optlen: int, ) -> None: ... def setsockopt( @@ -184,7 +197,7 @@ def setsockopt( ) -> None: pass - async def bind(self, address: Any) -> None: + async def bind(self, address: AddressFormat) -> None: pass def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None: @@ -252,7 +265,9 @@ async def getaddrinfo( ] async def getnameinfo( - self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + self, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, ) -> tuple[str, str]: raise NotImplementedError() @@ -268,8 +283,8 @@ async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None: (tsocket.AF_INET, "1.1.1.1"), (tsocket.AF_INET, "2.2.2.2"), (tsocket.AF_INET, "3.3.3.3"), - ] - ) + ], + ), ) with pytest.raises(FakeOSError): @@ -297,7 +312,9 @@ async def handler(stream: SendStream) -> None: async with trio.open_nursery() as nursery: # nursery.start is incorrectly typed, awaiting #2773 - listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) + value = await nursery.start(serve_tcp, handler, 0) + assert isinstance(value, list) + listeners = cast("list[SocketListener]", value) stream = await open_stream_to_socket_listener(listeners[0]) async with stream: assert await stream.receive_some(1) == b"x" @@ -313,14 +330,16 @@ async def handler(stream: SendStream) -> None: [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}], ) async def test_open_tcp_listeners_some_address_families_unavailable( - try_families: set[AddressFamily], fail_families: set[AddressFamily] + try_families: set[AddressFamily], + fail_families: set[AddressFamily], ) -> None: fsf = FakeSocketFactory( - 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families} + 10, + raise_on_family=dict.fromkeys(fail_families, errno.EAFNOSUPPORT), ) tsocket.set_custom_socket_factory(fsf) tsocket.set_custom_hostname_resolver( - FakeHostnameResolver([(family, "foo") for family in try_families]) + FakeHostnameResolver([(family, "foo") for family in try_families]), ) should_succeed = try_families - fail_families @@ -352,7 +371,7 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None: ) tsocket.set_custom_socket_factory(fsf) tsocket.set_custom_hostname_resolver( - FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]) + FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]), ) with pytest.raises(OSError, match="nope") as exc_info: @@ -389,8 +408,9 @@ async def test_open_tcp_listeners_backlog() -> None: async def test_open_tcp_listeners_backlog_float_error() -> None: fsf = FakeSocketFactory(99) tsocket.set_custom_socket_factory(fsf) - for should_fail in (0.0, 2.18, 3.14, 9.75): + for should_fail in (0.0, 2.18, 3.15, 9.75): with pytest.raises( - TypeError, match=f"backlog must be an int or None, not {should_fail!r}" + TypeError, + match=f"backlog must be an int or None, not {should_fail!r}", ): await open_tcp_listeners(0, backlog=should_fail) # type: ignore[arg-type] diff --git a/src/trio/_tests/test_highlevel_open_tcp_stream.py b/src/trio/_tests/test_highlevel_open_tcp_stream.py index ce1b1ac1de..98adf7efea 100644 --- a/src/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/src/trio/_tests/test_highlevel_open_tcp_stream.py @@ -3,7 +3,7 @@ import socket import sys from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING import attrs import pytest @@ -19,6 +19,8 @@ from trio.testing import Matcher, RaisesGroup if TYPE_CHECKING: + from collections.abc import Sequence + from trio.testing import MockClock if sys.version_info < (3, 11): @@ -156,12 +158,14 @@ async def test_local_address_real() -> None: local_address = "127.0.0.2" if can_bind_127_0_0_2() else "127.0.0.1" async with await open_tcp_stream( - *listener.getsockname(), local_address=local_address + *listener.getsockname(), + local_address=local_address, ) as client_stream: assert client_stream.socket.getsockname()[0] == local_address if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"): assert client_stream.socket.getsockopt( - trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT + trio.socket.IPPROTO_IP, + trio.socket.IP_BIND_ADDRESS_NO_PORT, ) server_sock, remote_addr = await listener.accept() await client_stream.aclose() @@ -172,13 +176,15 @@ async def test_local_address_real() -> None: # Trying to connect to an ipv4 address with the ipv6 wildcard # local_address should fail with pytest.raises( - OSError, match=r"^all attempts to connect* to *127\.0\.0\.\d:\d+ failed$" + OSError, + match=r"^all attempts to connect* to *127\.0\.0\.\d:\d+ failed$", ): await open_tcp_stream(*listener.getsockname(), local_address="::") # But the ipv4 wildcard address should work async with await open_tcp_stream( - *listener.getsockname(), local_address="0.0.0.0" + *listener.getsockname(), + local_address="0.0.0.0", ) as client_stream: server_sock, remote_addr = await listener.accept() server_sock.close() @@ -318,7 +324,9 @@ async def getaddrinfo( return [self._ip_to_gai_entry(ip) for ip in self.ip_order] async def getnameinfo( - self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + self, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, ) -> tuple[str, str]: raise NotImplementedError @@ -352,7 +360,8 @@ async def run_scenario( # If this is True, we require there to be an exception, and return # (exception, scenario object) expect_error: tuple[type[BaseException], ...] | type[BaseException] = (), - **kwargs: Any, + happy_eyeballs_delay: float | None = 0.25, + local_address: str | None = None, ) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]: supported_families = set() if ipv4_supported: @@ -364,7 +373,12 @@ async def run_scenario( trio.socket.set_custom_socket_factory(scenario) try: - stream = await open_tcp_stream("test.example.com", port, **kwargs) + stream = await open_tcp_stream( + "test.example.com", + port, + happy_eyeballs_delay=happy_eyeballs_delay, + local_address=local_address, + ) assert expect_error == () scenario.check(stream.socket) return (stream.socket, scenario) @@ -376,38 +390,44 @@ async def run_scenario( async def test_one_host_quick_success(autojump_clock: MockClock) -> None: - sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")]) + sock, _scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")]) assert isinstance(sock, FakeSocket) assert sock.ip == "1.2.3.4" assert trio.current_time() == 0.123 async def test_one_host_slow_success(autojump_clock: MockClock) -> None: - sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")]) + sock, _scenario = await run_scenario(81, [("1.2.3.4", 100, "success")]) assert isinstance(sock, FakeSocket) assert sock.ip == "1.2.3.4" assert trio.current_time() == 100 async def test_one_host_quick_fail(autojump_clock: MockClock) -> None: - exc, scenario = await run_scenario( - 82, [("1.2.3.4", 0.123, "error")], expect_error=OSError + exc, _scenario = await run_scenario( + 82, + [("1.2.3.4", 0.123, "error")], + expect_error=OSError, ) assert isinstance(exc, OSError) assert trio.current_time() == 0.123 async def test_one_host_slow_fail(autojump_clock: MockClock) -> None: - exc, scenario = await run_scenario( - 83, [("1.2.3.4", 100, "error")], expect_error=OSError + exc, _scenario = await run_scenario( + 83, + [("1.2.3.4", 100, "error")], + expect_error=OSError, ) assert isinstance(exc, OSError) assert trio.current_time() == 100 async def test_one_host_failed_after_connect(autojump_clock: MockClock) -> None: - exc, scenario = await run_scenario( - 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt + exc, _scenario = await run_scenario( + 83, + [("1.2.3.4", 1, "postconnect_fail")], + expect_error=KeyboardInterrupt, ) assert isinstance(exc, KeyboardInterrupt) @@ -532,7 +552,8 @@ async def test_all_fail(autojump_clock: MockClock) -> None: subexceptions = (Matcher(OSError, match="^sorry$"),) * 4 assert RaisesGroup( - *subexceptions, match="all attempts to connect to test.example.com:80 failed" + *subexceptions, + match="all attempts to connect to test.example.com:80 failed", ).matches(exc.__cause__) assert trio.current_time() == (0.1 + 0.2 + 10) @@ -643,7 +664,7 @@ async def test_handles_no_ipv6(autojump_clock: MockClock) -> None: async def test_no_hosts(autojump_clock: MockClock) -> None: - exc, scenario = await run_scenario(80, [], expect_error=OSError) + exc, _scenario = await run_scenario(80, [], expect_error=OSError) assert "no results found" in str(exc) diff --git a/src/trio/_tests/test_highlevel_open_unix_stream.py b/src/trio/_tests/test_highlevel_open_unix_stream.py index 38c31b8a4a..4441d751d2 100644 --- a/src/trio/_tests/test_highlevel_open_unix_stream.py +++ b/src/trio/_tests/test_highlevel_open_unix_stream.py @@ -12,7 +12,8 @@ assert not TYPE_CHECKING or sys.platform != "win32" skip_if_not_unix = pytest.mark.skipif( - not hasattr(socket, "AF_UNIX"), reason="Needs unix socket support" + not hasattr(socket, "AF_UNIX"), + reason="Needs unix socket support", ) @@ -79,6 +80,7 @@ async def test_open_unix_socket() -> None: @pytest.mark.skipif(hasattr(socket, "AF_UNIX"), reason="Test for non-unix platforms") async def test_error_on_no_unix() -> None: with pytest.raises( - RuntimeError, match="^Unix sockets are not supported on this platform$" + RuntimeError, + match="^Unix sockets are not supported on this platform$", ): await open_unix_socket("") diff --git a/src/trio/_tests/test_highlevel_serve_listeners.py b/src/trio/_tests/test_highlevel_serve_listeners.py index 1ce886eddb..9268555b32 100644 --- a/src/trio/_tests/test_highlevel_serve_listeners.py +++ b/src/trio/_tests/test_highlevel_serve_listeners.py @@ -2,7 +2,7 @@ import errno from functools import partial -from typing import TYPE_CHECKING, Awaitable, Callable, NoReturn +from typing import TYPE_CHECKING, NoReturn, cast import attrs @@ -19,6 +19,8 @@ ) if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + import pytest from trio._channel import MemoryReceiveChannel, MemorySendChannel @@ -29,7 +31,7 @@ StapledMemoryStream = StapledStream[MemorySendStream, MemoryReceiveStream] -@attrs.define(hash=False, eq=False, slots=False) +@attrs.define(eq=False, slots=False) class MemoryListener(trio.abc.Listener[StapledMemoryStream]): closed: bool = False accepted_streams: list[trio.abc.Stream] = attrs.Factory(list) @@ -94,9 +96,13 @@ async def do_tests(parent_nursery: Nursery) -> None: parent_nursery.cancel_scope.cancel() async with trio.open_nursery() as nursery: - l2: list[MemoryListener] = await nursery.start( - trio.serve_listeners, handler, listeners + value = await nursery.start( + trio.serve_listeners, + handler, + listeners, ) + assert isinstance(value, list) + l2 = cast("list[MemoryListener]", value) assert l2 == listeners # This is just split into another function because gh-136 isn't # implemented yet @@ -123,7 +129,8 @@ def check_error(e: BaseException) -> bool: async def test_serve_listeners_accept_capacity_error( - autojump_clock: MockClock, caplog: pytest.LogCaptureFixture + autojump_clock: MockClock, + caplog: pytest.LogCaptureFixture, ) -> None: listener = MemoryListener() @@ -155,7 +162,8 @@ class Done(Exception): pass async def connection_watcher( - *, task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED + *, + task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED, ) -> NoReturn: async with trio.open_nursery() as nursery: task_status.started(nursery) @@ -166,14 +174,16 @@ async def connection_watcher( # the exception is wrapped twice because we open two nested nurseries with RaisesGroup(RaisesGroup(Done)): async with trio.open_nursery() as nursery: - handler_nursery: trio.Nursery = await nursery.start(connection_watcher) + value = await nursery.start(connection_watcher) + assert isinstance(value, trio.Nursery) + handler_nursery: trio.Nursery = value await nursery.start( partial( trio.serve_listeners, handler, [listener], handler_nursery=handler_nursery, - ) + ), ) for _ in range(10): nursery.start_soon(listener.connect) diff --git a/src/trio/_tests/test_highlevel_socket.py b/src/trio/_tests/test_highlevel_socket.py index 976a3b5e04..a03efb0180 100644 --- a/src/trio/_tests/test_highlevel_socket.py +++ b/src/trio/_tests/test_highlevel_socket.py @@ -3,7 +3,7 @@ import errno import socket as stdlib_socket import sys -from typing import Sequence +from typing import TYPE_CHECKING import pytest @@ -16,7 +16,16 @@ ) from .test_socket import setsockopt_tests +if TYPE_CHECKING: + from collections.abc import Sequence + +@pytest.mark.xfail( + sys.platform == "darwin" and sys.version_info[:3] == (3, 13, 1), + reason="TODO: This started failing in CI after 3.13.1", + raises=OSError, + strict=True, +) async def test_SocketStream_basics() -> None: # stdlib socket bad (even if connected) stdlib_a, stdlib_b = stdlib_socket.socketpair() @@ -27,7 +36,8 @@ async def test_SocketStream_basics() -> None: # DGRAM socket bad with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock: with pytest.raises( - ValueError, match="^SocketStream requires a SOCK_STREAM socket$" + ValueError, + match="^SocketStream requires a SOCK_STREAM socket$", ): # TODO: does not raise an error? SocketStream(sock) @@ -130,7 +140,10 @@ async def waiter(nursery: _core.Nursery) -> None: async def test_SocketStream_generic() -> None: - async def stream_maker() -> tuple[SocketStream, SocketStream]: + async def stream_maker() -> tuple[ + SocketStream, + SocketStream, + ]: left, right = tsocket.socketpair() return SocketStream(left), SocketStream(right) @@ -155,7 +168,8 @@ async def test_SocketListener() -> None: with tsocket.socket(type=tsocket.SOCK_DGRAM) as s: await s.bind(("127.0.0.1", 0)) with pytest.raises( - ValueError, match="^SocketListener requires a SOCK_STREAM socket$" + ValueError, + match="^SocketListener requires a SOCK_STREAM socket$", ) as excinfo: SocketListener(s) excinfo.match(r".*SOCK_STREAM") @@ -166,7 +180,8 @@ async def test_SocketListener() -> None: with tsocket.socket() as s: await s.bind(("127.0.0.1", 0)) with pytest.raises( - ValueError, match="^SocketListener requires a listening socket$" + ValueError, + match="^SocketListener requires a listening socket$", ) as excinfo: SocketListener(s) excinfo.match(r".*listen") @@ -228,22 +243,39 @@ def getsockopt(self, /, level: int, optname: int) -> int: ... @overload def getsockopt( # noqa: F811 - self, /, level: int, optname: int, buflen: int + self, + /, + level: int, + optname: int, + buflen: int, ) -> bytes: ... def getsockopt( # noqa: F811 - self, /, level: int, optname: int, buflen: int | None = None + self, + /, + level: int, + optname: int, + buflen: int | None = None, ) -> int | bytes: return True @overload def setsockopt( - self, /, level: int, optname: int, value: int | Buffer + self, + /, + level: int, + optname: int, + value: int | Buffer, ) -> None: ... @overload def setsockopt( # noqa: F811 - self, /, level: int, optname: int, value: None, optlen: int + self, + /, + level: int, + optname: int, + value: None, + optlen: int, ) -> None: ... def setsockopt( # noqa: F811 @@ -276,7 +308,7 @@ async def accept(self) -> tuple[SocketType, object]: OSError(errno.EFAULT, "attempt to write to read-only memory"), OSError(errno.ENOBUFS, "out of buffers"), fake_server_sock, - ] + ], ) listener = SocketListener(fake_listen_sock) diff --git a/src/trio/_tests/test_highlevel_ssl_helpers.py b/src/trio/_tests/test_highlevel_ssl_helpers.py index 53f687d7c3..e42f311981 100644 --- a/src/trio/_tests/test_highlevel_ssl_helpers.py +++ b/src/trio/_tests/test_highlevel_ssl_helpers.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING, Any, NoReturn +from typing import TYPE_CHECKING, NoReturn, cast import attrs import pytest @@ -66,7 +66,11 @@ async def getaddrinfo( ]: return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)] - async def getnameinfo(self, *args: Any) -> NoReturn: # pragma: no cover + async def getnameinfo( + self, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, + ) -> NoReturn: # pragma: no cover raise NotImplementedError @@ -79,13 +83,17 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( # TODO: this function wraps an SSLListener around a SocketListener, this is illegal # according to current type hints, and probably for good reason. But there should # maybe be a different wrapper class/function that could be used instead? - res: list[SSLListener[SocketListener]] = ( # type: ignore[type-var] - await nursery.start( - partial( - serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1" - ) - ) + value = await nursery.start( + partial( + serve_ssl_over_tcp, + echo_handler, + 0, + SERVER_CTX, + host="127.0.0.1", + ), ) + assert isinstance(value, list) + res = cast("list[SSLListener[SocketListener]]", value) # type: ignore[type-var] (listener,) = res async with listener: # listener.transport_listener is of type Listener[Stream] @@ -105,7 +113,9 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( # We have the trust but not the hostname # (checks custom ssl_context + hostname checking) stream = await open_ssl_over_tcp_stream( - "xyzzy.example.org", 80, ssl_context=client_ctx + "xyzzy.example.org", + 80, + ssl_context=client_ctx, ) async with stream: with pytest.raises(trio.BrokenResourceError): @@ -113,7 +123,9 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( # This one should work! stream = await open_ssl_over_tcp_stream( - "trio-test-1.example.org", 80, ssl_context=client_ctx + "trio-test-1.example.org", + 80, + ssl_context=client_ctx, ) async with stream: assert isinstance(stream, trio.SSLStream) @@ -149,7 +161,10 @@ async def test_open_ssl_over_tcp_listeners() -> None: assert not listener._https_compatible (listener,) = await open_ssl_over_tcp_listeners( - 0, SERVER_CTX, host="127.0.0.1", https_compatible=True + 0, + SERVER_CTX, + host="127.0.0.1", + https_compatible=True, ) async with listener: assert listener._https_compatible diff --git a/src/trio/_tests/test_path.py b/src/trio/_tests/test_path.py index af29a0604b..aa383e7255 100644 --- a/src/trio/_tests/test_path.py +++ b/src/trio/_tests/test_path.py @@ -2,7 +2,7 @@ import os import pathlib -from typing import TYPE_CHECKING, Type, Union +from typing import TYPE_CHECKING, Union import pytest @@ -28,12 +28,12 @@ def method_pair( @pytest.mark.skipif(os.name == "nt", reason="OS is not posix") -async def test_instantiate_posix() -> None: +def test_instantiate_posix() -> None: assert isinstance(trio.Path(), trio.PosixPath) @pytest.mark.skipif(os.name != "nt", reason="OS is not Windows") -async def test_instantiate_windows() -> None: +def test_instantiate_windows() -> None: assert isinstance(trio.Path(), trio.WindowsPath) @@ -44,15 +44,15 @@ async def test_open_is_async_context_manager(path: trio.Path) -> None: assert f.closed -async def test_magic() -> None: +def test_magic() -> None: path = trio.Path("test") assert str(path) == "test" assert bytes(path) == b"test" -EitherPathType = Union[Type[trio.Path], Type[pathlib.Path]] -PathOrStrType = Union[EitherPathType, Type[str]] +EitherPathType = Union[type[trio.Path], type[pathlib.Path]] +PathOrStrType = Union[EitherPathType, type[str]] cls_pairs: list[tuple[EitherPathType, EitherPathType]] = [ (trio.Path, pathlib.Path), (pathlib.Path, trio.Path), @@ -61,7 +61,7 @@ async def test_magic() -> None: @pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs) -async def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None: +def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None: a, b = cls_a(""), cls_b("") assert a == b assert not a != b # noqa: SIM202 # negate-not-equal-op @@ -88,7 +88,7 @@ async def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None: @pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs_str) -async def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None: +def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None: a, b = cls_a("a"), cls_b("b") result = a / b # type: ignore[operator] @@ -98,24 +98,27 @@ async def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None: @pytest.mark.parametrize( - ("cls_a", "cls_b"), [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] + ("cls_a", "cls_b"), + [(trio.Path, pathlib.Path), (trio.Path, trio.Path)], ) @pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"]) -async def test_hash_magic( - cls_a: EitherPathType, cls_b: EitherPathType, path: str +def test_hash_magic( + cls_a: EitherPathType, + cls_b: EitherPathType, + path: str, ) -> None: a, b = cls_a(path), cls_b(path) assert hash(a) == hash(b) -async def test_forwarded_properties(path: trio.Path) -> None: +def test_forwarded_properties(path: trio.Path) -> None: # use `name` as a representative of forwarded properties assert "name" in dir(path) assert path.name == "test" -async def test_async_method_signature(path: trio.Path) -> None: +def test_async_method_signature(path: trio.Path) -> None: # use `resolve` as a representative of wrapped methods assert path.resolve.__name__ == "resolve" @@ -135,7 +138,7 @@ async def test_compare_async_stat_methods(method_name: str) -> None: assert result == async_result -async def test_invalid_name_not_wrapped(path: trio.Path) -> None: +def test_invalid_name_not_wrapped(path: trio.Path) -> None: with pytest.raises(AttributeError): getattr(path, "invalid_fake_attr") # noqa: B009 # "get-attr-with-constant" @@ -151,7 +154,7 @@ async def test_async_methods_rewrap(method_name: str) -> None: assert str(result) == str(async_result) -async def test_forward_methods_rewrap(path: trio.Path, tmp_path: pathlib.Path) -> None: +def test_forward_methods_rewrap(path: trio.Path, tmp_path: pathlib.Path) -> None: with_name = path.with_name("foo") with_suffix = path.with_suffix(".py") @@ -161,7 +164,7 @@ async def test_forward_methods_rewrap(path: trio.Path, tmp_path: pathlib.Path) - assert with_suffix == tmp_path / "test.py" -async def test_forward_properties_rewrap(path: trio.Path) -> None: +def test_forward_properties_rewrap(path: trio.Path) -> None: assert isinstance(path.parent, trio.Path) @@ -171,7 +174,7 @@ async def test_forward_methods_without_rewrap(path: trio.Path) -> None: assert path.as_uri().startswith("file:///") -async def test_repr() -> None: +def test_repr() -> None: path = trio.Path(".") assert repr(path) == "trio.Path('.')" @@ -190,7 +193,7 @@ async def test_path_wraps_path( assert wrapped == result -async def test_path_nonpath() -> None: +def test_path_nonpath() -> None: with pytest.raises(TypeError): trio.Path(1) # type: ignore @@ -264,7 +267,7 @@ async def test_classmethods() -> None: ], ) def test_wrapping_without_docstrings( - wrapper: Callable[[Callable[[], None]], Callable[[], None]] + wrapper: Callable[[Callable[[], None]], Callable[[], None]], ) -> None: @wrapper def func_without_docstring() -> None: ... # pragma: no cover diff --git a/src/trio/_tests/test_repl.py b/src/trio/_tests/test_repl.py index fbfdb07a05..be9338ce4c 100644 --- a/src/trio/_tests/test_repl.py +++ b/src/trio/_tests/test_repl.py @@ -42,7 +42,7 @@ def test_build_raw_input() -> None: # In 3.10 or later, types.FunctionType (used internally) will automatically # attach __builtins__ to the function objects. However we need to explicitly -# include it for 3.8 & 3.9 +# include it for 3.9 support def build_locals() -> dict[str, object]: return {"__builtins__": __builtins__} @@ -76,11 +76,11 @@ async def test_basic_interaction( # import works "import sys", "sys.stdout.write('hello stdout\\n')", - ] + ], ) monkeypatch.setattr(console, "raw_input", raw_input) await trio._repl.run_repl(console) - out, err = capsys.readouterr() + out, _err = capsys.readouterr() assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"] @@ -89,7 +89,7 @@ async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> raw_input = build_raw_input( [ "raise SystemExit", - ] + ], ) monkeypatch.setattr(console, "raw_input", raw_input) with pytest.raises(SystemExit): @@ -103,19 +103,18 @@ async def test_KI_interrupts( console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) raw_input = build_raw_input( [ - "from trio._util import signal_raise", "import signal, trio, trio.lowlevel", "async def f():", " trio.lowlevel.spawn_system_task(" " trio.to_thread.run_sync," - " signal_raise,signal.SIGINT," + " signal.raise_signal, signal.SIGINT," " )", # just awaiting this kills the test runner?! " await trio.sleep_forever()", " print('should not see this')", "", "await f()", "print('AFTER KeyboardInterrupt')", - ] + ], ) monkeypatch.setattr(console, "raw_input", raw_input) await trio._repl.run_repl(console) @@ -138,11 +137,11 @@ async def test_system_exits_in_exc_group( "", "raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])", "print('AFTER BaseExceptionGroup')", - ] + ], ) monkeypatch.setattr(console, "raw_input", raw_input) await trio._repl.run_repl(console) - out, err = capsys.readouterr() + out, _err = capsys.readouterr() # assert that raise SystemExit in an exception group # doesn't quit assert "AFTER BaseExceptionGroup" in out @@ -162,11 +161,11 @@ async def test_system_exits_in_nested_exc_group( "raise BaseExceptionGroup(", " '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])", "print('AFTER BaseExceptionGroup')", - ] + ], ) monkeypatch.setattr(console, "raw_input", raw_input) await trio._repl.run_repl(console) - out, err = capsys.readouterr() + out, _err = capsys.readouterr() # assert that raise SystemExit in an exception group # doesn't quit assert "AFTER BaseExceptionGroup" in out @@ -182,7 +181,7 @@ async def test_base_exception_captured( # The statement after raise should still get executed "raise BaseException", "print('AFTER BaseException')", - ] + ], ) monkeypatch.setattr(console, "raw_input", raw_input) await trio._repl.run_repl(console) @@ -202,11 +201,11 @@ async def test_exc_group_captured( # The statement after raise should still get executed "raise ExceptionGroup('', [KeyError()])", "print('AFTER ExceptionGroup')", - ] + ], ) monkeypatch.setattr(console, "raw_input", raw_input) await trio._repl.run_repl(console) - out, err = capsys.readouterr() + out, _err = capsys.readouterr() assert "AFTER ExceptionGroup" in out @@ -224,7 +223,7 @@ async def test_base_exception_capture_from_coroutine( # be executed "await async_func_raises_base_exception()", "print('AFTER BaseException')", - ] + ], ) monkeypatch.setattr(console, "raw_input", raw_input) await trio._repl.run_repl(console) diff --git a/src/trio/_tests/test_signals.py b/src/trio/_tests/test_signals.py index 5e639652ef..d149b86575 100644 --- a/src/trio/_tests/test_signals.py +++ b/src/trio/_tests/test_signals.py @@ -10,7 +10,6 @@ from .. import _core from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver -from .._util import signal_raise if TYPE_CHECKING: from types import FrameType @@ -21,16 +20,16 @@ async def test_open_signal_receiver() -> None: with open_signal_receiver(signal.SIGILL) as receiver: # Raise it a few times, to exercise signal coalescing, both at the # call_soon level and at the SignalQueue level - signal_raise(signal.SIGILL) - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGILL) await _core.wait_all_tasks_blocked() - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) await _core.wait_all_tasks_blocked() async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break assert get_pending_signal_count(receiver) == 0 - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break @@ -43,7 +42,8 @@ async def test_open_signal_receiver() -> None: async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None: orig = signal.getsignal(signal.SIGILL) with pytest.raises( - ValueError, match="(signal number out of range|invalid signal value)$" + ValueError, + match="(signal number out of range|invalid signal value)$", ): with open_signal_receiver(signal.SIGILL, 1234567): pass # pragma: no cover @@ -51,7 +51,7 @@ async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> No assert signal.getsignal(signal.SIGILL) is orig -async def test_open_signal_receiver_empty_fail() -> None: +def test_open_signal_receiver_empty_fail() -> None: with pytest.raises(TypeError, match="No signals were provided"): with open_signal_receiver(): pass @@ -100,8 +100,8 @@ async def test_open_signal_receiver_no_starvation() -> None: print(signal.getsignal(signal.SIGILL)) previous = None for _ in range(10): - signal_raise(signal.SIGILL) - signal_raise(signal.SIGFPE) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() if previous is None: previous = await receiver.__anext__() @@ -133,8 +133,8 @@ def direct_handler(signo: int, frame: FrameType | None) -> None: # before we exit the with block: with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler): with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: - signal_raise(signal.SIGILL) - signal_raise(signal.SIGFPE) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() assert delivered_directly == {signal.SIGILL, signal.SIGFPE} delivered_directly.clear() @@ -144,8 +144,8 @@ def direct_handler(signo: int, frame: FrameType | None) -> None: # we exit the with block: with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler): with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: - signal_raise(signal.SIGILL) - signal_raise(signal.SIGFPE) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() assert get_pending_signal_count(receiver) == 2 assert delivered_directly == {signal.SIGILL, signal.SIGFPE} @@ -156,14 +156,14 @@ def direct_handler(signo: int, frame: FrameType | None) -> None: print(3) with _signal_handler({signal.SIGILL}, signal.SIG_IGN): with open_signal_receiver(signal.SIGILL) as receiver: - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) await wait_run_sync_soon_idempotent_queue_barrier() # test passes if the process reaches this point without dying print(4) with _signal_handler({signal.SIGILL}, signal.SIG_IGN): with open_signal_receiver(signal.SIGILL) as receiver: - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) await wait_run_sync_soon_idempotent_queue_barrier() assert get_pending_signal_count(receiver) == 1 # test passes if the process reaches this point without dying @@ -176,8 +176,8 @@ def raise_handler(signum: int, frame: FrameType | None) -> NoReturn: with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler): with pytest.raises(RuntimeError) as excinfo: with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: - signal_raise(signal.SIGILL) - signal_raise(signal.SIGFPE) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() assert get_pending_signal_count(receiver) == 2 exc = excinfo.value diff --git a/src/trio/_tests/test_socket.py b/src/trio/_tests/test_socket.py index b98b3246e9..3e960bd9a4 100644 --- a/src/trio/_tests/test_socket.py +++ b/src/trio/_tests/test_socket.py @@ -6,33 +6,45 @@ import socket as stdlib_socket import sys import tempfile +from pathlib import Path from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union +from typing import TYPE_CHECKING, Union, cast import attrs import pytest from .. import _core, socket as tsocket -from .._core._tests.tutil import binds_ipv6, creates_ipv6 -from .._socket import _NUMERIC_ONLY, SocketType, _SocketType, _try_sync +from .._core._tests.tutil import binds_ipv6, can_create_ipv6, creates_ipv6 +from .._socket import _NUMERIC_ONLY, AddressFormat, SocketType, _SocketType, _try_sync from ..testing import assert_checkpoints, wait_all_tasks_blocked if TYPE_CHECKING: + from collections.abc import Callable + from typing_extensions import TypeAlias from .._highlevel_socket import SocketStream - GaiTuple: TypeAlias = Tuple[ + GaiTuple: TypeAlias = tuple[ AddressFamily, SocketKind, int, str, - Union[Tuple[str, int], Tuple[str, int, int, int]], + Union[tuple[str, int], tuple[str, int, int, int]], + ] + GetAddrInfoResponse: TypeAlias = list[GaiTuple] + GetAddrInfoArgs: TypeAlias = tuple[ + Union[str, bytes, None], + Union[str, bytes, int, None], + int, + int, + int, + int, ] - GetAddrInfoResponse: TypeAlias = List[GaiTuple] else: GaiTuple: object GetAddrInfoResponse = object + GetAddrInfoArgs = object ################################################################ # utils @@ -40,32 +52,75 @@ class MonkeypatchedGAI: - def __init__(self, orig_getaddrinfo: Callable[..., GetAddrInfoResponse]) -> None: + __slots__ = ("_orig_getaddrinfo", "_responses", "record") + + def __init__( + self, + orig_getaddrinfo: Callable[ + [str | bytes | None, str | bytes | int | None, int, int, int, int], + GetAddrInfoResponse, + ], + ) -> None: self._orig_getaddrinfo = orig_getaddrinfo - self._responses: dict[tuple[Any, ...], GetAddrInfoResponse | str] = {} - self.record: list[tuple[Any, ...]] = [] + self._responses: dict[ + GetAddrInfoArgs, + GetAddrInfoResponse | str, + ] = {} + self.record: list[GetAddrInfoArgs] = [] # get a normalized getaddrinfo argument tuple - def _frozenbind(self, *args: Any, **kwargs: Any) -> tuple[Any, ...]: + def _frozenbind( + self, + host: str | bytes | None, + port: str | bytes | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> GetAddrInfoArgs: sig = inspect.signature(self._orig_getaddrinfo) - bound = sig.bind(*args, **kwargs) + bound = sig.bind(host, port, family=family, type=type, proto=proto, flags=flags) bound.apply_defaults() frozenbound = bound.args assert not bound.kwargs return frozenbound def set( - self, response: GetAddrInfoResponse | str, *args: Any, **kwargs: Any + self, + response: GetAddrInfoResponse | str, + host: str | bytes | None, + port: str | bytes | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, ) -> None: - self._responses[self._frozenbind(*args, **kwargs)] = response - - def getaddrinfo(self, *args: Any, **kwargs: Any) -> GetAddrInfoResponse | str: - bound = self._frozenbind(*args, **kwargs) + self._responses[ + self._frozenbind( + host, + port, + family=family, + type=type, + proto=proto, + flags=flags, + ) + ] = response + + def getaddrinfo( + self, + host: str | bytes | None, + port: str | bytes | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> GetAddrInfoResponse | str: + bound = self._frozenbind(host, port, family, type, proto, flags) self.record.append(bound) if bound in self._responses: return self._responses[bound] - elif bound[-1] & stdlib_socket.AI_NUMERICHOST: - return self._orig_getaddrinfo(*args, **kwargs) + elif flags & stdlib_socket.AI_NUMERICHOST: + return self._orig_getaddrinfo(host, port, family, type, proto, flags) else: raise RuntimeError(f"gai called with unexpected arguments {bound}") @@ -134,7 +189,7 @@ def interesting_fields( tuple[str, int] | tuple[str, int, int] | tuple[str, int, int, int], ]: # (family, type, proto, canonname, sockaddr) - family, type_, proto, canonname, sockaddr = gai_tup + family, type_, _proto, _canonname, sockaddr = gai_tup return (family, type_, sockaddr) def filtered( @@ -321,9 +376,10 @@ async def test_sniff_sockopts() -> None: from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM # generate the combinations of families/types we're testing: + families = (AF_INET, AF_INET6) if can_create_ipv6 else (AF_INET,) sockets = [ stdlib_socket.socket(family, type_) - for family in [AF_INET, AF_INET6] + for family in families for type_ in [SOCK_DGRAM, SOCK_STREAM] ] for socket in sockets: @@ -400,6 +456,12 @@ async def test_SocketType_basics() -> None: sock.close() +@pytest.mark.xfail( + sys.platform == "darwin" and sys.version_info[:3] == (3, 13, 1), + reason="TODO: This started failing in CI after 3.13.1", + raises=OSError, + strict=True, +) async def test_SocketType_setsockopt() -> None: sock = tsocket.socket() with sock as _: @@ -472,7 +534,8 @@ async def test_SocketType_shutdown() -> None: ], ) async def test_SocketType_simple_server( - address: str, socket_type: AddressFamily + address: str, + socket_type: AddressFamily, ) -> None: # listen, bind, accept, connect, getpeername, getsockname listener = tsocket.socket(socket_type) @@ -556,7 +619,8 @@ def pad(addr: tuple[str | int, ...]) -> tuple[str | int, ...]: return addr def assert_eq( - actual: tuple[str | int, ...], expected: tuple[str | int, ...] + actual: tuple[str | int, ...], + expected: tuple[str | int, ...], ) -> None: assert pad(expected) == pad(actual) @@ -588,12 +652,14 @@ async def res( | tuple[str, str] | tuple[str, str, int] | tuple[str, str, int, int] - ) - ) -> Any: - return await sock._resolve_address_nocp( + ), + ) -> tuple[str | int, ...]: + value = await sock._resolve_address_nocp( args, local=local, # noqa: B023 # local is not bound in function definition ) + assert isinstance(value, tuple) + return cast("tuple[Union[str, int], ...]", value) assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) if v6: @@ -639,7 +705,8 @@ async def res( # smoke test the basic functionality... try: netlink_sock = tsocket.socket( - family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM + family=tsocket.AF_NETLINK, + type=tsocket.SOCK_DGRAM, ) except (AttributeError, OSError): pass @@ -787,15 +854,20 @@ async def test_SocketType_connect_paths() -> None: # nose -- and then swap it back out again before we hit # wait_socket_writable, which insists on a real socket. class CancelSocket(stdlib_socket.socket): - def connect(self, *args: Any, **kwargs: Any) -> None: + def connect( + self, + address: AddressFormat, + ) -> None: # accessing private method only available in _SocketType assert isinstance(sock, _SocketType) cancel_scope.cancel() sock._sock = stdlib_socket.fromfd( - self.detach(), self.family, self.type + self.detach(), + self.family, + self.type, ) - sock._sock.connect(*args, **kwargs) + sock._sock.connect(address) # If connect *doesn't* raise, then pretend it did raise BlockingIOError # pragma: no cover @@ -842,13 +914,17 @@ async def test_resolve_address_exception_in_connect_closes_socket() -> None: with tsocket.socket() as sock: async def _resolve_address_nocp( - self: Any, *args: Any, **kwargs: Any + address: AddressFormat, + *, + local: bool, ) -> None: + assert address == "" + assert not local cancel_scope.cancel() await _core.checkpoint() assert isinstance(sock, _SocketType) - sock._resolve_address_nocp = _resolve_address_nocp # type: ignore[method-assign, assignment] + sock._resolve_address_nocp = _resolve_address_nocp # type: ignore[method-assign] with assert_checkpoints(): with pytest.raises(_core.Cancelled): await sock.connect("") @@ -991,7 +1067,9 @@ async def getaddrinfo( return ("custom_gai", host, port, family, type, proto, flags) async def getnameinfo( - self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + self, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, ) -> tuple[str, tuple[str, int] | tuple[str, int, int, int], int]: return ("custom_gni", sockaddr, flags) @@ -1067,7 +1145,7 @@ def socket( assert tsocket.set_custom_socket_factory(None) is csf -async def test_SocketType_is_abstract() -> None: +def test_SocketType_is_abstract() -> None: with pytest.raises(TypeError): tsocket.SocketType() @@ -1077,7 +1155,7 @@ async def test_unix_domain_socket() -> None: # Bind has a special branch to use a thread, since it has to do filesystem # traversal. Maybe connect should too? Not sure. - async def check_AF_UNIX(path: str | bytes) -> None: + async def check_AF_UNIX(path: str | bytes | os.PathLike[str]) -> None: with tsocket.socket(family=tsocket.AF_UNIX) as lsock: await lsock.bind(path) lsock.listen(10) @@ -1091,8 +1169,11 @@ async def check_AF_UNIX(path: str | bytes) -> None: # Can't use tmpdir fixture, because we can exceed the maximum AF_UNIX path # length on macOS. with tempfile.TemporaryDirectory() as tmpdir: - path = f"{tmpdir}/sock" - await check_AF_UNIX(path) + # Test passing various supported types as path + # Must use different filenames to prevent "address already in use" + await check_AF_UNIX(f"{tmpdir}/sock") + await check_AF_UNIX(Path(f"{tmpdir}/sock1")) + await check_AF_UNIX(os.fsencode(f"{tmpdir}/sock2")) try: cookie = os.urandom(20).hex().encode("ascii") diff --git a/src/trio/_tests/test_ssl.py b/src/trio/_tests/test_ssl.py index 8e780b2f9c..d271743c7a 100644 --- a/src/trio/_tests/test_ssl.py +++ b/src/trio/_tests/test_ssl.py @@ -11,10 +11,6 @@ from typing import ( TYPE_CHECKING, Any, - AsyncIterator, - Awaitable, - Callable, - Iterator, NoReturn, ) @@ -56,6 +52,8 @@ ) if TYPE_CHECKING: + from collections.abc import AsyncIterator, Awaitable, Callable, Iterator + from typing_extensions import TypeAlias from trio._core import MockClock @@ -116,11 +114,15 @@ def client_ctx(request: pytest.FixtureRequest) -> ssl.SSLContext: # The blocking socket server. def ssl_echo_serve_sync( - sock: stdlib_socket.socket, *, expect_fail: bool = False + sock: stdlib_socket.socket, + *, + expect_fail: bool = False, ) -> None: try: wrapped = SERVER_CTX.wrap_socket( - sock, server_side=True, suppress_ragged_eofs=False + sock, + server_side=True, + suppress_ragged_eofs=False, ) with wrapped: wrapped.do_handshake() @@ -166,8 +168,8 @@ def ssl_echo_serve_sync( # Fixture that gives a raw socket connected to a trio-test-1 echo server # (running in a thread). Useful for testing making connections with different # SSLContexts. -@asynccontextmanager # type: ignore[misc] # decorated contains Any -async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]: +@asynccontextmanager +async def ssl_echo_server_raw(expect_fail: bool = False) -> AsyncIterator[SocketStream]: a, b = stdlib_socket.socketpair() async with trio.open_nursery() as nursery: # Exiting the 'with a, b' context manager closes the sockets, which @@ -175,7 +177,8 @@ async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]: # nursery context manager to exit too. with a, b: nursery.start_soon( - trio.to_thread.run_sync, partial(ssl_echo_serve_sync, b, **kwargs) + trio.to_thread.run_sync, + partial(ssl_echo_serve_sync, b, expect_fail=expect_fail), ) yield SocketStream(tsocket.from_stdlib_socket(a)) @@ -183,11 +186,12 @@ async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]: # Fixture that gives a properly set up SSLStream connected to a trio-test-1 # echo server (running in a thread) -@asynccontextmanager # type: ignore[misc] # decorated contains Any +@asynccontextmanager async def ssl_echo_server( - client_ctx: SSLContext, **kwargs: Any + client_ctx: SSLContext, + expect_fail: bool = False, ) -> AsyncIterator[SSLStream[Stream]]: - async with ssl_echo_server_raw(**kwargs) as sock: + async with ssl_echo_server_raw(expect_fail=expect_fail) as sock: yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") @@ -197,33 +201,20 @@ async def ssl_echo_server( # jakkdl: it seems to implement all the abstract methods (now), so I made it inherit # from Stream for the sake of typechecking. class PyOpenSSLEchoStream(Stream): - def __init__(self, sleeper: None = None) -> None: + def __init__( + self, + sleeper: Callable[[str], Awaitable[None]] | None = None, + ) -> None: ctx = SSL.Context(SSL.SSLv23_METHOD) # TLS 1.3 removes renegotiation support. Which is great for them, but # we still have to support versions before that, and that means we # need to test renegotiation support, which means we need to force this # to use a lower version where this test server can trigger - # renegotiations. Of course TLS 1.3 support isn't released yet, but - # I'm told that this will work once it is. (And once it is we can - # remove the pragma: no cover too.) Alternatively, we could switch to - # using TLSv1_2_METHOD. - # - # Discussion: https://github.com/pyca/pyopenssl/issues/624 - - # This is the right way, but we can't use it until this PR is in a - # released: - # https://github.com/pyca/pyopenssl/pull/861 - # - # if hasattr(SSL, "OP_NO_TLSv1_3"): - # ctx.set_options(SSL.OP_NO_TLSv1_3) - # - # Fortunately pyopenssl uses cryptography under the hood, so we can be - # confident that they're using the same version of openssl + # renegotiations. from cryptography.hazmat.bindings.openssl.binding import Binding b = Binding() - if hasattr(b.lib, "SSL_OP_NO_TLSv1_3"): - ctx.set_options(b.lib.SSL_OP_NO_TLSv1_3) + ctx.set_options(b.lib.SSL_OP_NO_TLSv1_3) # Unfortunately there's currently no way to say "use 1.3 or worse", we # can only disable specific versions. And if the two sides start @@ -239,12 +230,13 @@ def __init__(self, sleeper: None = None) -> None: self._pending_cleartext = bytearray() self._send_all_conflict_detector = ConflictDetector( - "simultaneous calls to PyOpenSSLEchoStream.send_all" + "simultaneous calls to PyOpenSSLEchoStream.send_all", ) self._receive_some_conflict_detector = ConflictDetector( - "simultaneous calls to PyOpenSSLEchoStream.receive_some" + "simultaneous calls to PyOpenSSLEchoStream.receive_some", ) + self.sleeper: Callable[[str], Awaitable[None]] if sleeper is None: async def no_op_sleeper(_: object) -> None: @@ -358,7 +350,10 @@ async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None: # PyOpenSSLEchoStream will notice and complain. async def do_test( - func1: str, args1: tuple[object, ...], func2: str, args2: tuple[object, ...] + func1: str, + args1: tuple[object, ...], + func2: str, + args2: tuple[object, ...], ) -> None: s = PyOpenSSLEchoStream() with RaisesGroup(Matcher(_core.BusyResourceError, "simultaneous")): @@ -369,20 +364,25 @@ async def do_test( await do_test("send_all", (b"x",), "send_all", (b"x",)) await do_test("send_all", (b"x",), "wait_send_all_might_not_block", ()) await do_test( - "wait_send_all_might_not_block", (), "wait_send_all_might_not_block", () + "wait_send_all_might_not_block", + (), + "wait_send_all_might_not_block", + (), ) await do_test("receive_some", (1,), "receive_some", (1,)) -@contextmanager # type: ignore[misc] # decorated contains Any +@contextmanager def virtual_ssl_echo_server( - client_ctx: SSLContext, **kwargs: Any + client_ctx: SSLContext, + sleeper: Callable[[str], Awaitable[None]] | None = None, ) -> Iterator[SSLStream[PyOpenSSLEchoStream]]: - fakesock = PyOpenSSLEchoStream(**kwargs) + fakesock = PyOpenSSLEchoStream(sleeper=sleeper) yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org") -def ssl_wrap_pair( +# Explicit "Any" is not allowed +def ssl_wrap_pair( # type: ignore[misc] client_ctx: SSLContext, client_transport: T_Stream, server_transport: T_Stream, @@ -401,7 +401,10 @@ def ssl_wrap_pair( **client_kwargs, ) server_ssl = SSLStream( - server_transport, SERVER_CTX, server_side=True, **server_kwargs + server_transport, + SERVER_CTX, + server_side=True, + **server_kwargs, ) return client_ssl, server_ssl @@ -409,23 +412,43 @@ def ssl_wrap_pair( MemoryStapledStream: TypeAlias = StapledStream[MemorySendStream, MemoryReceiveStream] -def ssl_memory_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[ +def ssl_memory_stream_pair( + client_ctx: SSLContext, + client_kwargs: dict[str, str | bytes | bool | None] | None = None, + server_kwargs: dict[str, str | bytes | bool | None] | None = None, +) -> tuple[ SSLStream[MemoryStapledStream], SSLStream[MemoryStapledStream], ]: client_transport, server_transport = memory_stream_pair() - return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) + return ssl_wrap_pair( + client_ctx, + client_transport, + server_transport, + client_kwargs=client_kwargs, + server_kwargs=server_kwargs, + ) MyStapledStream: TypeAlias = StapledStream[SendStream, ReceiveStream] -def ssl_lockstep_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[ +def ssl_lockstep_stream_pair( + client_ctx: SSLContext, + client_kwargs: dict[str, str | bytes | bool | None] | None = None, + server_kwargs: dict[str, str | bytes | bool | None] | None = None, +) -> tuple[ SSLStream[MyStapledStream], SSLStream[MyStapledStream], ]: client_transport, server_transport = lockstep_stream_pair() - return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) + return ssl_wrap_pair( + client_ctx, + client_transport, + server_transport, + client_kwargs=client_kwargs, + server_kwargs=server_kwargs, + ) # Simple smoke test for handshake/send/receive/shutdown talking to a @@ -462,13 +485,16 @@ async def test_ssl_server_basics(client_ctx: SSLContext) -> None: with a, b: server_sock = tsocket.from_stdlib_socket(b) server_transport = SSLStream( - SocketStream(server_sock), SERVER_CTX, server_side=True + SocketStream(server_sock), + SERVER_CTX, + server_side=True, ) assert server_transport.server_side def client() -> None: with client_ctx.wrap_socket( - a, server_hostname="trio-test-1.example.org" + a, + server_hostname="trio-test-1.example.org", ) as client_sock: client_sock.sendall(b"x") assert client_sock.recv(1) == b"y" @@ -612,7 +638,8 @@ async def test_renegotiation_simple(client_ctx: SSLContext) -> None: @slow async def test_renegotiation_randomized( - mock_clock: MockClock, client_ctx: SSLContext + mock_clock: MockClock, + client_ctx: SSLContext, ) -> None: # The only blocking things in this function are our random sleeps, so 0 is # a good threshold. @@ -687,7 +714,8 @@ async def expect(expected: bytes) -> None: # Our receive_some() call will get stuck when it hits send_all async def sleeper_with_slow_send_all(method: str) -> None: if method == "send_all": - await trio.sleep(100000) + # ignore ASYNC116, not sleep_forever, trying to test a large but finite sleep + await trio.sleep(100000) # noqa: ASYNC116 # And our wait_send_all_might_not_block call will give it time to get # stuck, and then start @@ -711,12 +739,14 @@ async def sleep_then_wait_writable() -> None: async def sleeper_with_slow_wait_writable_and_expect(method: str) -> None: if method == "wait_send_all_might_not_block": - await trio.sleep(100000) + # ignore ASYNC116, not sleep_forever, trying to test a large but finite sleep + await trio.sleep(100000) # noqa: ASYNC116 elif method == "expect": await trio.sleep(1000) with virtual_ssl_echo_server( - client_ctx, sleeper=sleeper_with_slow_wait_writable_and_expect + client_ctx, + sleeper=sleeper_with_slow_wait_writable_and_expect, ) as s: await send(b"x") s.transport_stream.renegotiate() @@ -747,7 +777,8 @@ async def do_wait_send_all_might_not_block(s: S) -> None: await s.wait_send_all_might_not_block() async def do_test( - func1: Callable[[S], Awaitable[None]], func2: Callable[[S], Awaitable[None]] + func1: Callable[[S], Awaitable[None]], + func2: Callable[[S], Awaitable[None]], ) -> None: s, _ = ssl_lockstep_stream_pair(client_ctx) with RaisesGroup(Matcher(_core.BusyResourceError, "another task")): @@ -835,7 +866,8 @@ async def test_send_all_empty_string(client_ctx: SSLContext) -> None: @pytest.mark.parametrize("https_compatible", [False, True]) async def test_SSLStream_generic( - client_ctx: SSLContext, https_compatible: bool + client_ctx: SSLContext, + https_compatible: bool, ) -> None: async def stream_maker() -> tuple[ SSLStream[MemoryStapledStream], @@ -1017,12 +1049,16 @@ async def test_ssl_over_ssl(client_ctx: SSLContext) -> None: client_0, server_0 = memory_stream_pair() client_1 = SSLStream( - client_0, client_ctx, server_hostname="trio-test-1.example.org" + client_0, + client_ctx, + server_hostname="trio-test-1.example.org", ) server_1 = SSLStream(server_0, SERVER_CTX, server_side=True) client_2 = SSLStream( - client_1, client_ctx, server_hostname="trio-test-1.example.org" + client_1, + client_ctx, + server_hostname="trio-test-1.example.org", ) server_2 = SSLStream(server_1, SERVER_CTX, server_side=True) @@ -1161,7 +1197,7 @@ async def server_expect_clean_eof() -> None: async def test_send_error_during_handshake(client_ctx: SSLContext) -> None: - client, server = ssl_memory_stream_pair(client_ctx) + client, _server = ssl_memory_stream_pair(client_ctx) async def bad_hook() -> NoReturn: raise KeyError @@ -1200,7 +1236,7 @@ async def client_side(cancel_scope: CancelScope) -> None: await client.do_handshake() -async def test_selected_alpn_protocol_before_handshake(client_ctx: SSLContext) -> None: +def test_selected_alpn_protocol_before_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1225,7 +1261,7 @@ async def test_selected_alpn_protocol_when_not_set(client_ctx: SSLContext) -> No assert client.selected_alpn_protocol() == server.selected_alpn_protocol() -async def test_selected_npn_protocol_before_handshake(client_ctx: SSLContext) -> None: +def test_selected_npn_protocol_before_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1254,7 +1290,7 @@ async def test_selected_npn_protocol_when_not_set(client_ctx: SSLContext) -> Non assert client.selected_npn_protocol() == server.selected_npn_protocol() -async def test_get_channel_binding_before_handshake(client_ctx: SSLContext) -> None: +def test_get_channel_binding_before_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1292,17 +1328,23 @@ async def test_getpeercert(client_ctx: SSLContext) -> None: async def test_SSLListener(client_ctx: SSLContext) -> None: async def setup( - **kwargs: Any, + https_compatible: bool = False, ) -> tuple[tsocket.SocketType, SSLListener[SocketStream], SSLStream[SocketStream]]: listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(1) socket_listener = SocketListener(listen_sock) - ssl_listener = SSLListener(socket_listener, SERVER_CTX, **kwargs) + ssl_listener = SSLListener( + socket_listener, + SERVER_CTX, + https_compatible=https_compatible, + ) transport_client = await open_tcp_stream(*listen_sock.getsockname()) ssl_client = SSLStream( - transport_client, client_ctx, server_hostname="trio-test-1.example.org" + transport_client, + client_ctx, + server_hostname="trio-test-1.example.org", ) return listen_sock, ssl_listener, ssl_client diff --git a/src/trio/_tests/test_subprocess.py b/src/trio/_tests/test_subprocess.py index 0a70e7a974..bf6742064d 100644 --- a/src/trio/_tests/test_subprocess.py +++ b/src/trio/_tests/test_subprocess.py @@ -6,18 +6,17 @@ import signal import subprocess import sys -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager from functools import partial from pathlib import Path as SyncPath from signal import Signals from typing import ( TYPE_CHECKING, Any, - AsyncContextManager, - AsyncIterator, - Callable, NoReturn, ) +from unittest import mock import pytest @@ -83,15 +82,11 @@ def SLEEP(seconds: int) -> list[str]: return python(f"import time; time.sleep({seconds})") -def got_signal(proc: Process, sig: SignalType) -> bool: - if (not TYPE_CHECKING and posix) or sys.platform != "win32": - return proc.returncode == -sig - else: - return proc.returncode != 0 - - -@asynccontextmanager # type: ignore[misc] # Any in decorator -async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Process]: +@asynccontextmanager # type: ignore[misc] # Any in decorated +async def open_process_then_kill( + *args: Any, + **kwargs: Any, +) -> AsyncIterator[Process]: proc = await open_process(*args, **kwargs) try: yield proc @@ -100,11 +95,16 @@ async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Pro await proc.wait() -@asynccontextmanager # type: ignore[misc] # Any in decorator -async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Process]: +@asynccontextmanager # type: ignore[misc] # Any in decorated +async def run_process_in_nursery( + *args: Any, + **kwargs: Any, +) -> AsyncIterator[Process]: async with _core.open_nursery() as nursery: kwargs.setdefault("check", False) - proc: Process = await nursery.start(partial(run_process, *args, **kwargs)) + value = await nursery.start(partial(run_process, *args, **kwargs)) + assert isinstance(value, Process) + proc: Process = value yield proc nursery.cancel_scope.cancel() @@ -115,7 +115,11 @@ async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Pro ids=["open_process", "run_process in nursery"], ) -BackgroundProcessType: TypeAlias = Callable[..., AsyncContextManager[Process]] +# Explicit .../"Any" is not allowed +BackgroundProcessType: TypeAlias = Callable[ # type: ignore[misc] + ..., + AbstractAsyncContextManager[Process], +] @background_process_param @@ -131,10 +135,31 @@ async def test_basic(background_process: BackgroundProcessType) -> None: await proc.wait() assert proc.returncode == 1 assert repr(proc) == "".format( - EXIT_FALSE, "exited with status 1" + EXIT_FALSE, + "exited with status 1", ) +@background_process_param +async def test_basic_no_pidfd(background_process: BackgroundProcessType) -> None: + with mock.patch("trio._subprocess.can_try_pidfd_open", new=False): + async with background_process(EXIT_TRUE) as proc: + assert proc._pidfd is None + await proc.wait() + assert isinstance(proc, Process) + assert proc._pidfd is None + assert proc.returncode == 0 + assert repr(proc) == f"" + + async with background_process(EXIT_FALSE) as proc: + await proc.wait() + assert proc.returncode == 1 + assert repr(proc) == "".format( + EXIT_FALSE, + "exited with status 1", + ) + + @background_process_param async def test_auto_update_returncode( background_process: BackgroundProcessType, @@ -170,10 +195,31 @@ async def test_multi_wait(background_process: BackgroundProcessType) -> None: proc.kill() +@background_process_param +async def test_multi_wait_no_pidfd(background_process: BackgroundProcessType) -> None: + with mock.patch("trio._subprocess.can_try_pidfd_open", new=False): + async with background_process(SLEEP(10)) as proc: + # Check that wait (including multi-wait) tolerates being cancelled + async with _core.open_nursery() as nursery: + nursery.start_soon(proc.wait) + nursery.start_soon(proc.wait) + nursery.start_soon(proc.wait) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + # Now try waiting for real + async with _core.open_nursery() as nursery: + nursery.start_soon(proc.wait) + nursery.start_soon(proc.wait) + nursery.start_soon(proc.wait) + await wait_all_tasks_blocked() + proc.kill() + + COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python( "data = sys.stdin.buffer.read(); " "sys.stdout.buffer.write(data); " - "sys.stderr.buffer.write(data[::-1])" + "sys.stderr.buffer.write(data[::-1])", ) @@ -234,7 +280,7 @@ async def test_interactive(background_process: BackgroundProcessType) -> None: " request = int(line.strip())\n" " print(str(idx * 2) * request)\n" " print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n" - " idx += 1\n" + " idx += 1\n", ), stdin=subprocess.PIPE, stdout=subprocess.PIPE, @@ -246,7 +292,9 @@ async def expect(idx: int, request: int) -> None: async with _core.open_nursery() as nursery: async def drain_one( - stream: ReceiveStream, count: int, digit: int + stream: ReceiveStream, + count: int, + digit: int, ) -> None: while count > 0: result = await stream.receive_some(count) @@ -291,7 +339,10 @@ async def test_run() -> None: data = bytes(random.randint(0, 255) for _ in range(2**18)) result = await run_process( - CAT, stdin=data, capture_stdout=True, capture_stderr=True + CAT, + stdin=data, + capture_stdout=True, + capture_stderr=True, ) assert result.args == CAT assert result.returncode == 0 @@ -325,7 +376,8 @@ async def test_run() -> None: with pytest.raises(ValueError, match=pipe_stdout_error): await run_process(CAT, stdout=subprocess.PIPE) with pytest.raises( - ValueError, match=pipe_stdout_error.replace("stdout", "stderr", 1) + ValueError, + match=pipe_stdout_error.replace("stdout", "stderr", 1), ): await run_process(CAT, stderr=subprocess.PIPE) with pytest.raises( @@ -350,7 +402,10 @@ async def test_run_check() -> None: assert excinfo.value.stdout is None result = await run_process( - cmd, capture_stdout=True, capture_stderr=True, check=False + cmd, + capture_stdout=True, + capture_stderr=True, + check=False, ) assert result.args == cmd assert result.stdout == b"" @@ -361,7 +416,8 @@ async def test_run_check() -> None: @skip_if_fbsd_pipes_broken async def test_run_with_broken_pipe() -> None: result = await run_process( - [sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072 + [sys.executable, "-c", "import sys; sys.stdin.close()"], + stdin=b"x" * 131072, ) assert result.returncode == 0 assert result.stdout is result.stderr is None @@ -404,7 +460,9 @@ async def test_stderr_stdout(background_process: BackgroundProcessType) -> None: # this one hits the branch where stderr=STDOUT but stdout # is not redirected async with background_process( - CAT, stdin=subprocess.PIPE, stderr=subprocess.STDOUT + CAT, + stdin=subprocess.PIPE, + stderr=subprocess.STDOUT, ) as proc: assert proc.stdout is None assert proc.stderr is None @@ -452,7 +510,8 @@ async def test_errors() -> None: @background_process_param async def test_signals(background_process: BackgroundProcessType) -> None: async def test_one_signal( - send_it: Callable[[Process], None], signum: signal.Signals | None + send_it: Callable[[Process], None], + signum: signal.Signals | None, ) -> None: with move_on_after(1.0) as scope: async with background_process(SLEEP(3600)) as proc: @@ -500,6 +559,31 @@ async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> signal.signal(signal.SIGCHLD, old_sigchld) +@pytest.mark.skipif(not posix, reason="POSIX specific") +@background_process_param +async def test_wait_reapable_fails_no_pidfd( + background_process: BackgroundProcessType, +) -> None: + if TYPE_CHECKING and sys.platform == "win32": + return + with mock.patch("trio._subprocess.can_try_pidfd_open", new=False): + old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN) + try: + # With SIGCHLD disabled, the wait() syscall will wait for the + # process to exit but then fail with ECHILD. Make sure we + # support this case as the stdlib subprocess module does. + async with background_process(SLEEP(3600)) as proc: + async with _core.open_nursery() as nursery: + nursery.start_soon(proc.wait) + await wait_all_tasks_blocked() + proc.kill() + nursery.cancel_scope.deadline = _core.current_time() + 1.0 + assert not nursery.cancel_scope.cancelled_caught + assert proc.returncode == 0 # exit status unknowable, so... + finally: + signal.signal(signal.SIGCHLD, old_sigchld) + + @slow def test_waitid_eintr() -> None: # This only matters on PyPy (where we're coding EINTR handling @@ -557,7 +641,7 @@ async def custom_deliver_cancel(proc: Process) -> None: async with _core.open_nursery() as nursery: nursery.start_soon( - partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel) + partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel), ) await wait_all_tasks_blocked() nursery.cancel_scope.cancel() @@ -573,7 +657,7 @@ async def custom_deliver_cancel(proc: Process) -> None: async def do_stuff() -> None: async with _core.open_nursery() as nursery: nursery.start_soon( - partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel) + partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel), ) await wait_all_tasks_blocked() nursery.cancel_scope.cancel() @@ -601,7 +685,8 @@ def broken_terminate(self: Process) -> NoReturn: @pytest.mark.skipif(not posix, reason="posix only") async def test_warn_on_cancel_SIGKILL_escalation( - autojump_clock: MockClock, monkeypatch: pytest.MonkeyPatch + autojump_clock: MockClock, + monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr(Process, "terminate", lambda *args: None) @@ -617,7 +702,9 @@ async def test_warn_on_cancel_SIGKILL_escalation( async def test_run_process_background_fail() -> None: with RaisesGroup(subprocess.CalledProcessError): async with _core.open_nursery() as nursery: - proc: Process = await nursery.start(run_process, EXIT_FALSE) + value = await nursery.start(run_process, EXIT_FALSE) + assert isinstance(value, Process) + proc: Process = value assert proc.returncode == 1 diff --git a/src/trio/_tests/test_sync.py b/src/trio/_tests/test_sync.py index e4d04202cb..77721226a3 100644 --- a/src/trio/_tests/test_sync.py +++ b/src/trio/_tests/test_sync.py @@ -1,11 +1,15 @@ from __future__ import annotations +import re import weakref from typing import TYPE_CHECKING, Callable, Union import pytest +from trio.testing import Matcher, RaisesGroup + from .. import _core +from .._core._parking_lot import GLOBAL_PARKING_LOT_BREAKER from .._sync import * from .._timeouts import sleep_forever from ..testing import assert_checkpoints, wait_all_tasks_blocked @@ -228,7 +232,7 @@ async def do_acquire(s: Semaphore) -> None: assert record == ["started", "finished"] -async def test_Semaphore_bounded() -> None: +def test_Semaphore_bounded() -> None: with pytest.raises(TypeError): Semaphore(1, max_value=1.0) # type: ignore[arg-type] with pytest.raises(ValueError, match="^max_values must be >= initial_value$"): @@ -494,7 +498,9 @@ def release(self) -> None: ] generic_lock_test = pytest.mark.parametrize( - "lock_factory", lock_factories, ids=lock_factory_names + "lock_factory", + lock_factories, + ids=lock_factory_names, ) LockLike: TypeAlias = Union[ @@ -584,3 +590,66 @@ async def lock_taker() -> None: await wait_all_tasks_blocked() assert record == ["started"] lock_like.release() + + +async def test_lock_acquire_unowned_lock() -> None: + """Test that trying to acquire a lock whose owner has exited raises an error. + see https://github.com/python-trio/trio/issues/3035 + """ + assert not GLOBAL_PARKING_LOT_BREAKER + lock = trio.Lock() + async with trio.open_nursery() as nursery: + nursery.start_soon(lock.acquire) + owner_str = re.escape(str(lock._lot.broken_by[0])) + with pytest.raises( + trio.BrokenResourceError, + match=f"^Owner of this lock exited without releasing: {owner_str}$", + ): + await lock.acquire() + assert not GLOBAL_PARKING_LOT_BREAKER + + +async def test_lock_multiple_acquire() -> None: + """Test for error if awaiting on a lock whose owner exits without releasing. + see https://github.com/python-trio/trio/issues/3035""" + assert not GLOBAL_PARKING_LOT_BREAKER + lock = trio.Lock() + with RaisesGroup( + Matcher( + trio.BrokenResourceError, + match="^Owner of this lock exited without releasing: ", + ), + ): + async with trio.open_nursery() as nursery: + nursery.start_soon(lock.acquire) + nursery.start_soon(lock.acquire) + assert not GLOBAL_PARKING_LOT_BREAKER + + +async def test_lock_handover() -> None: + assert not GLOBAL_PARKING_LOT_BREAKER + child_task: Task | None = None + lock = trio.Lock() + + # this task acquires the lock + lock.acquire_nowait() + assert GLOBAL_PARKING_LOT_BREAKER == { + _core.current_task(): [ + lock._lot, + ], + } + + async with trio.open_nursery() as nursery: + nursery.start_soon(lock.acquire) + await wait_all_tasks_blocked() + + # hand over the lock to the child task + lock.release() + + # check values, and get the identifier out of the dict for later check + assert len(GLOBAL_PARKING_LOT_BREAKER) == 1 + child_task = next(iter(GLOBAL_PARKING_LOT_BREAKER)) + assert GLOBAL_PARKING_LOT_BREAKER[child_task] == [lock._lot] + + assert lock._lot.broken_by == [child_task] + assert not GLOBAL_PARKING_LOT_BREAKER diff --git a/src/trio/_tests/test_testing.py b/src/trio/_tests/test_testing.py index 0f2778dc15..ab47213766 100644 --- a/src/trio/_tests/test_testing.py +++ b/src/trio/_tests/test_testing.py @@ -236,7 +236,7 @@ async def child(i: int) -> None: ################################################################ -async def test__assert_raises() -> None: +def test__assert_raises() -> None: with pytest.raises(AssertionError): with _assert_raises(RuntimeError): 1 + 1 # noqa: B018 # "useless expression" @@ -393,7 +393,9 @@ def close_hook() -> None: record.append("close_hook") mss2 = MemorySendStream( - send_all_hook, wait_send_all_might_not_block_hook, close_hook + send_all_hook, + wait_send_all_might_not_block_hook, + close_hook, ) assert mss2.send_all_hook is send_all_hook @@ -670,10 +672,13 @@ async def check(listener: SocketListener) -> None: def test_trio_test() -> None: async def busy_kitchen( - *, mock_clock: object, autojump_clock: object + *, + mock_clock: object, + autojump_clock: object, ) -> None: ... # pragma: no cover with pytest.raises(ValueError, match="^too many clocks spoil the broth!$"): trio_test(busy_kitchen)( - mock_clock=MockClock(), autojump_clock=MockClock(autojump_threshold=0) + mock_clock=MockClock(), + autojump_clock=MockClock(autojump_threshold=0), ) diff --git a/src/trio/_tests/test_testing_raisesgroup.py b/src/trio/_tests/test_testing_raisesgroup.py index 1e96d38e52..bb86d88646 100644 --- a/src/trio/_tests/test_testing_raisesgroup.py +++ b/src/trio/_tests/test_testing_raisesgroup.py @@ -3,7 +3,6 @@ import re import sys from types import TracebackType -from typing import Any import pytest @@ -22,10 +21,10 @@ def test_raises_group() -> None: with pytest.raises( ValueError, match=wrap_escape( - f'Invalid argument "{TypeError()!r}" must be exception type, Matcher, or RaisesGroup.' + f'Invalid argument "{TypeError()!r}" must be exception type, Matcher, or RaisesGroup.', ), ): - RaisesGroup(TypeError()) + RaisesGroup(TypeError()) # type: ignore[call-overload] with RaisesGroup(ValueError): raise ExceptionGroup("foo", (ValueError(),)) @@ -94,7 +93,8 @@ def test_flatten_subgroups() -> None: raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)): raise ExceptionGroup( - "", (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),) + "", + (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),), ) with pytest.raises(ExceptionGroup): with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)): @@ -116,7 +116,8 @@ def test_catch_unwrapped_exceptions() -> None: # expecting multiple unwrapped exceptions is not possible with pytest.raises( - ValueError, match="^You cannot specify multiple exceptions with" + ValueError, + match="^You cannot specify multiple exceptions with", ): RaisesGroup(SyntaxError, ValueError, allow_unwrapped=True) # type: ignore[call-overload] # if users want one of several exception types they need to use a Matcher @@ -233,7 +234,10 @@ def test_RaisesGroup_matches() -> None: def test_message() -> None: - def check_message(message: str, body: RaisesGroup[Any]) -> None: + def check_message( + message: str, + body: RaisesGroup[BaseException], + ) -> None: with pytest.raises( AssertionError, match=f"^DID NOT RAISE any exception, expected {re.escape(message)}$", @@ -245,7 +249,8 @@ def check_message(message: str, body: RaisesGroup[Any]) -> None: check_message("ExceptionGroup(ValueError)", RaisesGroup(ValueError)) # multiple exceptions check_message( - "ExceptionGroup(ValueError, ValueError)", RaisesGroup(ValueError, ValueError) + "ExceptionGroup(ValueError, ValueError)", + RaisesGroup(ValueError, ValueError), ) # nested check_message( @@ -265,7 +270,8 @@ def check_message(message: str, body: RaisesGroup[Any]) -> None: # BaseExceptionGroup check_message( - "BaseExceptionGroup(KeyboardInterrupt)", RaisesGroup(KeyboardInterrupt) + "BaseExceptionGroup(KeyboardInterrupt)", + RaisesGroup(KeyboardInterrupt), ) # BaseExceptionGroup with type inside Matcher check_message( @@ -286,7 +292,8 @@ def check_message(message: str, body: RaisesGroup[Any]) -> None: def test_matcher() -> None: with pytest.raises( - ValueError, match="^You must specify at least one parameter to match on.$" + ValueError, + match="^You must specify at least one parameter to match on.$", ): Matcher() # type: ignore[call-overload] with pytest.raises( @@ -346,9 +353,9 @@ def check_errno_is_5(e: OSError) -> bool: def test_matcher_tostring() -> None: assert str(Matcher(ValueError)) == "Matcher(ValueError)" assert str(Matcher(match="[a-z]")) == "Matcher(match='[a-z]')" - pattern_no_flags = re.compile("noflag", 0) + pattern_no_flags = re.compile(r"noflag", 0) assert str(Matcher(match=pattern_no_flags)) == "Matcher(match='noflag')" - pattern_flags = re.compile("noflag", re.IGNORECASE) + pattern_flags = re.compile(r"noflag", re.IGNORECASE) assert str(Matcher(match=pattern_flags)) == f"Matcher(match={pattern_flags!r})" assert ( str(Matcher(ValueError, match="re", check=bool)) @@ -367,12 +374,3 @@ def test__ExceptionInfo(monkeypatch: pytest.MonkeyPatch) -> None: assert excinfo.type is ExceptionGroup assert excinfo.value.exceptions[0].args == ("hello",) assert isinstance(excinfo.tb, TracebackType) - - -def test_deprecated_strict() -> None: - """`strict` has been replaced with `flatten_subgroups`""" - # parameter is not included in overloaded signatures at all - with pytest.deprecated_call(): - RaisesGroup(ValueError, strict=False) # type: ignore[call-overload] - with pytest.deprecated_call(): - RaisesGroup(ValueError, strict=True) # type: ignore[call-overload] diff --git a/src/trio/_tests/test_threads.py b/src/trio/_tests/test_threads.py index b4a5842ff0..75f9142d69 100644 --- a/src/trio/_tests/test_threads.py +++ b/src/trio/_tests/test_threads.py @@ -10,13 +10,7 @@ from functools import partial from typing import ( TYPE_CHECKING, - AsyncGenerator, - Awaitable, - Callable, - List, NoReturn, - Tuple, - Type, TypeVar, Union, ) @@ -48,18 +42,21 @@ from ..testing import wait_all_tasks_blocked if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Awaitable, Callable + from outcome import Outcome from ..lowlevel import Task -RecordType = List[Tuple[str, Union[threading.Thread, Type[BaseException]]]] +RecordType = list[tuple[str, Union[threading.Thread, type[BaseException]]]] T = TypeVar("T") async def test_do_in_trio_thread() -> None: trio_thread = threading.current_thread() - async def check_case( + # Explicit "Any" is not allowed + async def check_case( # type: ignore[misc] do_in_trio_thread: Callable[..., threading.Thread], fn: Callable[..., T | Awaitable[T]], expected: tuple[str, T], @@ -223,7 +220,7 @@ def f(name: str) -> Callable[[None], threading.Thread]: # test that you can set a custom name, and that it's reset afterwards async def test_thread_name(name: str) -> None: thread = await to_thread_run_sync(f(name), thread_name=name) - assert re.match("Trio thread [0-9]*", thread.name) + assert re.match(r"Trio thread [0-9]*", thread.name) await test_thread_name("") await test_thread_name("fobiedoo") @@ -304,7 +301,7 @@ async def test_thread_name(name: str, expected: str | None = None) -> None: os_thread_name = _get_thread_name(thread.ident) assert os_thread_name is not None, "should skip earlier if this is the case" - assert re.match("Trio thread [0-9]*", os_thread_name) + assert re.match(r"Trio thread [0-9]*", os_thread_name) await test_thread_name("") await test_thread_name("fobiedoo") @@ -313,7 +310,7 @@ async def test_thread_name(name: str, expected: str | None = None) -> None: await test_thread_name("๐Ÿ’™", expected="?") -async def test_has_pthread_setname_np() -> None: +def test_has_pthread_setname_np() -> None: from trio._core._thread_cache import get_os_thread_name_func k = get_os_thread_name_func() @@ -336,7 +333,8 @@ def g() -> NoReturn: raise ValueError(threading.current_thread()) with pytest.raises( - ValueError, match=r"^$" + ValueError, + match=r"^$", ) as excinfo: await to_thread_run_sync(g) print(excinfo.value.args) @@ -404,7 +402,8 @@ async def child(q: stdlib_queue.Queue[None], abandon_on_cancel: bool) -> None: # handled gracefully. (Requires that the thread result machinery be prepared # for call_soon to raise RunFinishedError.) def test_run_in_worker_thread_abandoned( - capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch + capfd: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) @@ -444,7 +443,9 @@ async def child() -> None: @pytest.mark.parametrize("cancel", [False, True]) @pytest.mark.parametrize("use_default_limiter", [False, True]) async def test_run_in_worker_thread_limiter( - MAX: int, cancel: bool, use_default_limiter: bool + MAX: int, + cancel: bool, + use_default_limiter: bool, ) -> None: # This test is a bit tricky. The goal is to make sure that if we set # limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever @@ -647,7 +648,7 @@ async def async_fn() -> None: # pragma: no cover trio_test_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar( - "trio_test_contextvar" + "trio_test_contextvar", ) @@ -879,7 +880,7 @@ async def agen(token: _core.TrioToken | None) -> AsyncGenerator[None, None]: with _core.CancelScope(shield=True): try: await to_thread_run_sync( - partial(from_thread_run, sleep, 0, trio_token=token) + partial(from_thread_run, sleep, 0, trio_token=token), ) except _core.RunFinishedError: record.append("finished") @@ -1072,7 +1073,7 @@ def f() -> None: # type: ignore[no-redef] # noqa: F811 assert q.get(timeout=1) == "Cancelled" -async def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None: +def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None: with pytest.raises(RuntimeError): from_thread_check_cancelled() q: stdlib_queue.Queue[Outcome[object]] = stdlib_queue.Queue() diff --git a/src/trio/_tests/test_timeouts.py b/src/trio/_tests/test_timeouts.py index 98c3d18def..052520b2d9 100644 --- a/src/trio/_tests/test_timeouts.py +++ b/src/trio/_tests/test_timeouts.py @@ -1,14 +1,30 @@ +from __future__ import annotations + import time -from typing import Awaitable, Callable, TypeVar +from typing import TYPE_CHECKING, Protocol, TypeVar import outcome import pytest +import trio + from .. import _core from .._core._tests.tutil import slow -from .._timeouts import * +from .._timeouts import ( + TooSlowError, + fail_after, + fail_at, + move_on_after, + move_on_at, + sleep, + sleep_forever, + sleep_until, +) from ..testing import assert_checkpoints +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + T = TypeVar("T") @@ -75,6 +91,68 @@ async def sleep_3() -> None: await check_takes_about(sleep_3, TARGET) +async def test_cannot_wake_sleep_forever() -> None: + # Test an error occurs if you manually wake sleep_forever(). + task = trio.lowlevel.current_task() + + async def wake_task() -> None: + await trio.lowlevel.checkpoint() + trio.lowlevel.reschedule(task, outcome.Value(None)) + + async with trio.open_nursery() as nursery: + nursery.start_soon(wake_task) + with pytest.raises(RuntimeError): + await trio.sleep_forever() + + +class TimeoutScope(Protocol): + def __call__(self, seconds: float, *, shield: bool) -> trio.CancelScope: ... + + +@pytest.mark.parametrize("scope", [move_on_after, fail_after]) +async def test_context_shields_from_outer(scope: TimeoutScope) -> None: + with _core.CancelScope() as outer, scope(TARGET, shield=True) as inner: + outer.cancel() + try: + await trio.lowlevel.checkpoint() + except trio.Cancelled: # pragma: no cover + pytest.fail("shield didn't work") + inner.shield = False + with pytest.raises(trio.Cancelled): + await trio.lowlevel.checkpoint() + + +@slow +async def test_move_on_after_moves_on_even_if_shielded() -> None: + async def task() -> None: + with _core.CancelScope() as outer, move_on_after(TARGET, shield=True): + outer.cancel() + # The outer scope is cancelled, but this task is protected by the + # shield, so it manages to get to sleep until deadline is met + await sleep_forever() + + await check_takes_about(task, TARGET) + + +@slow +async def test_fail_after_fails_even_if_shielded() -> None: + async def task() -> None: + # fmt: off + # Remove after 3.9 unsupported, black formats in a way that breaks if + # you do `-X oldparser` + with pytest.raises(TooSlowError), _core.CancelScope() as outer, fail_after( + TARGET, + shield=True, + ): + # fmt: on + outer.cancel() + # The outer scope is cancelled, but this task is protected by the + # shield, so it manages to get to sleep until deadline is met + await sleep_forever() + + await check_takes_about(task, TARGET) + + @slow async def test_fail() -> None: async def sleep_4() -> None: @@ -111,7 +189,7 @@ async def test_timeouts_raise_value_error() -> None: ): with pytest.raises( ValueError, - match="^(duration|deadline|timeout) must (not )*be (non-negative|NaN)$", + match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$", ): await fun(val) @@ -125,7 +203,79 @@ async def test_timeouts_raise_value_error() -> None: ): with pytest.raises( ValueError, - match="^(duration|deadline|timeout) must (not )*be (non-negative|NaN)$", + match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$", ): with cm(val): pass # pragma: no cover + + +async def test_timeout_deadline_on_entry(mock_clock: _core.MockClock) -> None: + rcs = move_on_after(5) + assert rcs.relative_deadline == 5 + + mock_clock.jump(3) + start = _core.current_time() + with rcs as cs: + assert cs.is_relative is None + + # This would previously be start+2 + assert cs.deadline == start + 5 + assert cs.relative_deadline == 5 + + cs.deadline = start + 3 + assert cs.deadline == start + 3 + assert cs.relative_deadline == 3 + + cs.relative_deadline = 4 + assert cs.deadline == start + 4 + assert cs.relative_deadline == 4 + + rcs = move_on_after(5) + assert rcs.shield is False + rcs.shield = True + assert rcs.shield is True + + mock_clock.jump(3) + start = _core.current_time() + with rcs as cs: + assert cs.deadline == start + 5 + + assert rcs is cs + + +async def test_invalid_access_unentered(mock_clock: _core.MockClock) -> None: + cs = move_on_after(5) + mock_clock.jump(3) + start = _core.current_time() + + match_str = "^unentered relative cancel scope does not have an absolute deadline" + with pytest.warns(DeprecationWarning, match=match_str): + assert cs.deadline == start + 5 + mock_clock.jump(1) + # this is hella sketchy, but they *have* been warned + with pytest.warns(DeprecationWarning, match=match_str): + assert cs.deadline == start + 6 + + with pytest.warns(DeprecationWarning, match=match_str): + cs.deadline = 7 + # now transformed into absolute + assert cs.deadline == 7 + assert not cs.is_relative + + cs = move_on_at(5) + + match_str = ( + "^unentered non-relative cancel scope does not have a relative deadline$" + ) + with pytest.raises(RuntimeError, match=match_str): + assert cs.relative_deadline + with pytest.raises(RuntimeError, match=match_str): + cs.relative_deadline = 7 + + +@pytest.mark.xfail(reason="not implemented") +async def test_fail_access_before_entering() -> None: # pragma: no cover + my_fail_at = fail_at(5) + assert my_fail_at.deadline # type: ignore[attr-defined] + my_fail_after = fail_after(5) + assert my_fail_after.relative_deadline # type: ignore[attr-defined] diff --git a/src/trio/_tests/test_tracing.py b/src/trio/_tests/test_tracing.py index 5cf758c6b6..52ea9bfa40 100644 --- a/src/trio/_tests/test_tracing.py +++ b/src/trio/_tests/test_tracing.py @@ -1,7 +1,12 @@ -from typing import AsyncGenerator +from __future__ import annotations + +from typing import TYPE_CHECKING import trio +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + async def coro1(event: trio.Event) -> None: event.set() diff --git a/src/trio/_tests/test_trio.py b/src/trio/_tests/test_trio.py new file mode 100644 index 0000000000..65d4ce34f2 --- /dev/null +++ b/src/trio/_tests/test_trio.py @@ -0,0 +1,8 @@ +def test_trio_import() -> None: + import sys + + for module in list(sys.modules.keys()): + if module.startswith("trio"): + del sys.modules[module] + + import trio # noqa: F401 diff --git a/src/trio/_tests/test_unix_pipes.py b/src/trio/_tests/test_unix_pipes.py index 6f8fa6e02e..c850ebefea 100644 --- a/src/trio/_tests/test_unix_pipes.py +++ b/src/trio/_tests/test_unix_pipes.py @@ -12,6 +12,9 @@ from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken from ..testing import check_one_way_stream, wait_all_tasks_blocked +if TYPE_CHECKING: + from .._file_io import _HasFileNo + posix = os.name == "posix" pytestmark = pytest.mark.skipif(not posix, reason="posix only") @@ -30,7 +33,7 @@ async def make_pipe() -> tuple[FdStream, FdStream]: return FdStream(w), FdStream(r) -async def make_clogged_pipe(): +async def make_clogged_pipe() -> tuple[FdStream, FdStream]: s, r = await make_pipe() try: while True: @@ -197,8 +200,11 @@ async def expect_closedresourceerror() -> None: orig_wait_readable = _core._run.TheIOManager.wait_readable - async def patched_wait_readable(*args, **kwargs) -> None: - await orig_wait_readable(*args, **kwargs) + async def patched_wait_readable( + self: _core._run.TheIOManager, + fd: int | _HasFileNo, + ) -> None: + await orig_wait_readable(self, fd) await r.aclose() monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable) @@ -225,8 +231,11 @@ async def expect_closedresourceerror() -> None: orig_wait_writable = _core._run.TheIOManager.wait_writable - async def patched_wait_writable(*args, **kwargs) -> None: - await orig_wait_writable(*args, **kwargs) + async def patched_wait_writable( + self: _core._run.TheIOManager, + fd: int | _HasFileNo, + ) -> None: + await orig_wait_writable(self, fd) await s.aclose() monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable) diff --git a/src/trio/_tests/test_util.py b/src/trio/_tests/test_util.py index 3e62eb622e..5036d76e52 100644 --- a/src/trio/_tests/test_util.py +++ b/src/trio/_tests/test_util.py @@ -1,8 +1,11 @@ -import signal +from __future__ import annotations + import sys import types -from typing import Any, TypeVar +from typing import TYPE_CHECKING, TypeVar +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Coroutine, Generator import pytest import trio @@ -21,25 +24,13 @@ fixup_module_metadata, generic_function, is_main_thread, - signal_raise, ) from ..testing import wait_all_tasks_blocked -T = TypeVar("T") - - -def test_signal_raise() -> None: - record = [] +if TYPE_CHECKING: + from collections.abc import AsyncGenerator - def handler(signum: int, _: object) -> None: - record.append(signum) - - old = signal.signal(signal.SIGFPE, handler) - try: - signal_raise(signal.SIGFPE) - finally: - signal.signal(signal.SIGFPE, old) - assert record == [signal.SIGFPE] +T = TypeVar("T") async def test_ConflictDetector() -> None: @@ -116,9 +107,11 @@ async def f() -> None: # pragma: no cover import asyncio if sys.version_info < (3, 11): - # not bothering to type this one - @asyncio.coroutine # type: ignore[misc] - def generator_based_coro() -> Any: # pragma: no cover + + @asyncio.coroutine + def generator_based_coro() -> ( + Generator[Coroutine[None, None, None], None, None] + ): # pragma: no cover yield from asyncio.sleep(1) with pytest.raises(TypeError) as excinfo: @@ -147,12 +140,13 @@ def generator_based_coro() -> Any: # pragma: no cover assert "appears to be synchronous" in str(excinfo.value) - async def async_gen(_: object) -> Any: # pragma: no cover + async def async_gen( + _: object, + ) -> AsyncGenerator[None, None]: # pragma: no cover yield - # does not give arg-type typing error with pytest.raises(TypeError) as excinfo: - coroutine_or_error(async_gen, [0]) # type: ignore[unused-coroutine] + coroutine_or_error(async_gen, [0]) # type: ignore[arg-type,unused-coroutine] msg = "expected an async function but got an async generator" assert msg in str(excinfo.value) @@ -198,13 +192,13 @@ class SpecialClass(metaclass=NoPublicConstructor): def __init__(self, a: int, b: float) -> None: """Check arguments can be passed to __init__.""" assert a == 8 - assert b == 3.14 + assert b == 3.15 with pytest.raises(TypeError): - SpecialClass(8, 3.14) + SpecialClass(8, 3.15) # Private constructor should not raise, and passes args to __init__. - assert isinstance(SpecialClass._create(8, b=3.14), SpecialClass) + assert isinstance(SpecialClass._create(8, b=3.15), SpecialClass) def test_fixup_module_metadata() -> None: diff --git a/src/trio/_tests/test_wait_for_object.py b/src/trio/_tests/test_wait_for_object.py index 54bbb77567..7a3472ba3a 100644 --- a/src/trio/_tests/test_wait_for_object.py +++ b/src/trio/_tests/test_wait_for_object.py @@ -16,7 +16,7 @@ from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject -async def test_WaitForMultipleObjects_sync() -> None: +def test_WaitForMultipleObjects_sync() -> None: # This does a series of tests where we set/close the handle before # initiating the waiting for it. # @@ -81,7 +81,9 @@ async def test_WaitForMultipleObjects_sync_slow() -> None: t0 = _core.current_time() async with _core.open_nursery() as nursery: nursery.start_soon( - trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1 + trio.to_thread.run_sync, + WaitForMultipleObjects_sync, + handle1, ) await _timeouts.sleep(TIMEOUT) # If we would comment the line below, the above thread will be stuck, @@ -98,7 +100,10 @@ async def test_WaitForMultipleObjects_sync_slow() -> None: t0 = _core.current_time() async with _core.open_nursery() as nursery: nursery.start_soon( - trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2 + trio.to_thread.run_sync, + WaitForMultipleObjects_sync, + handle1, + handle2, ) await _timeouts.sleep(TIMEOUT) kernel32.SetEvent(handle1) @@ -114,7 +119,10 @@ async def test_WaitForMultipleObjects_sync_slow() -> None: t0 = _core.current_time() async with _core.open_nursery() as nursery: nursery.start_soon( - trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2 + trio.to_thread.run_sync, + WaitForMultipleObjects_sync, + handle1, + handle2, ) await _timeouts.sleep(TIMEOUT) kernel32.SetEvent(handle2) diff --git a/src/trio/_tests/test_windows_pipes.py b/src/trio/_tests/test_windows_pipes.py index 38a25cdc54..e42736d65d 100644 --- a/src/trio/_tests/test_windows_pipes.py +++ b/src/trio/_tests/test_windows_pipes.py @@ -28,7 +28,7 @@ async def make_pipe() -> tuple[PipeSendStream, PipeReceiveStream]: return PipeSendStream(w), PipeReceiveStream(r) -async def test_pipe_typecheck() -> None: +def test_pipe_typecheck() -> None: with pytest.raises(TypeError): PipeSendStream(1.0) # type: ignore[arg-type] with pytest.raises(TypeError): @@ -92,7 +92,7 @@ async def test_async_with() -> None: async def test_close_during_write() -> None: - w, r = await make_pipe() + w, _r = await make_pipe() async with _core.open_nursery() as nursery: async def write_forever() -> None: diff --git a/src/trio/_tests/tools/test_gen_exports.py b/src/trio/_tests/tools/test_gen_exports.py index 19158451f7..669df968e0 100644 --- a/src/trio/_tests/tools/test_gen_exports.py +++ b/src/trio/_tests/tools/test_gen_exports.py @@ -91,7 +91,11 @@ def test_create_pass_through_args() -> None: @skip_lints @pytest.mark.parametrize("imports", [IMPORT_1, IMPORT_2, IMPORT_3]) -def test_process(tmp_path: Path, imports: str) -> None: +def test_process( + tmp_path: Path, + imports: str, + capsys: pytest.CaptureFixture[str], +) -> None: try: import black # noqa: F401 # there's no dedicated CI run that has astor+isort, but lacks black. @@ -106,7 +110,13 @@ def test_process(tmp_path: Path, imports: str) -> None: with pytest.raises(SystemExit) as excinfo: process([file], do_test=True) assert excinfo.value.code == 1 - process([file], do_test=False) + captured = capsys.readouterr() + assert "Generated sources are outdated. Please regenerate." in captured.out + with pytest.raises(SystemExit) as excinfo: + process([file], do_test=False) + assert excinfo.value.code == 1 + captured = capsys.readouterr() + assert "Regenerated sources successfully." in captured.out assert genpath.exists() process([file], do_test=True) # But if we change the lookup path it notices diff --git a/src/trio/_tests/tools/test_mypy_annotate.py b/src/trio/_tests/tools/test_mypy_annotate.py index 0ff4babb99..09a57ce745 100644 --- a/src/trio/_tests/tools/test_mypy_annotate.py +++ b/src/trio/_tests/tools/test_mypy_annotate.py @@ -105,6 +105,8 @@ def test_endtoend( monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str], ) -> None: + import trio._tools.mypy_annotate as mypy_annotate + inp_text = """\ Mypy begun trio/core.py:15: error: Bad types here [misc] @@ -116,7 +118,9 @@ def test_endtoend( with monkeypatch.context(): monkeypatch.setattr(sys, "stdin", io.StringIO(inp_text)) - main(["--dumpfile", str(result_file), "--platform", "SomePlatform"]) + mypy_annotate.main( + ["--dumpfile", str(result_file), "--platform", "SomePlatform"], + ) std = capsys.readouterr() assert std.err == "" diff --git a/src/trio/_tests/type_tests/path.py b/src/trio/_tests/type_tests/path.py index 15d25ae954..6749d06276 100644 --- a/src/trio/_tests/type_tests/path.py +++ b/src/trio/_tests/type_tests/path.py @@ -4,7 +4,7 @@ import os import pathlib import sys -from typing import IO, Any, BinaryIO, List, Tuple +from typing import IO, Any, BinaryIO import trio from trio._file_io import AsyncIOWrapper @@ -35,7 +35,7 @@ def operator_checks(text: str, tpath: trio.Path, ppath: pathlib.Path) -> None: def sync_attrs(path: trio.Path) -> None: - assert_type(path.parts, Tuple[str, ...]) + assert_type(path.parts, tuple[str, ...]) assert_type(path.drive, str) assert_type(path.root, str) assert_type(path.anchor, str) @@ -43,22 +43,20 @@ def sync_attrs(path: trio.Path) -> None: assert_type(path.parent, trio.Path) assert_type(path.name, str) assert_type(path.suffix, str) - assert_type(path.suffixes, List[str]) + assert_type(path.suffixes, list[str]) assert_type(path.stem, str) assert_type(path.as_posix(), str) assert_type(path.as_uri(), str) assert_type(path.is_absolute(), bool) - if sys.version_info > (3, 9): - assert_type(path.is_relative_to(path), bool) + assert_type(path.is_relative_to(path), bool) assert_type(path.is_reserved(), bool) assert_type(path.joinpath(path, "folder"), trio.Path) assert_type(path.match("*.py"), bool) assert_type(path.relative_to("/usr"), trio.Path) - if sys.version_info > (3, 12): - assert_type(path.relative_to("/", walk_up=True), bool) + if sys.version_info >= (3, 12): + assert_type(path.relative_to("/", walk_up=True), trio.Path) assert_type(path.with_name("filename.txt"), trio.Path) - if sys.version_info > (3, 9): - assert_type(path.with_stem("readme"), trio.Path) + assert_type(path.with_stem("readme"), trio.Path) assert_type(path.with_suffix(".log"), trio.Path) @@ -75,7 +73,7 @@ async def async_attrs(path: trio.Path) -> None: assert_type(await path.group(), str) assert_type(await path.is_dir(), bool) assert_type(await path.is_file(), bool) - if sys.version_info > (3, 12): + if sys.version_info >= (3, 12): assert_type(await path.is_junction(), bool) if sys.platform != "win32": assert_type(await path.is_mount(), bool) @@ -95,8 +93,7 @@ async def async_attrs(path: trio.Path) -> None: assert_type(await path.owner(), str) assert_type(await path.read_bytes(), bytes) assert_type(await path.read_text(encoding="utf16", errors="replace"), str) - if sys.version_info > (3, 9): - assert_type(await path.readlink(), trio.Path) + assert_type(await path.readlink(), trio.Path) assert_type(await path.rename("another"), trio.Path) assert_type(await path.replace(path), trio.Path) assert_type(await path.resolve(), trio.Path) @@ -107,13 +104,14 @@ async def async_attrs(path: trio.Path) -> None: assert_type(await path.rmdir(), None) assert_type(await path.samefile("something_else"), bool) assert_type(await path.symlink_to("somewhere"), None) - if sys.version_info > (3, 10): + if sys.version_info >= (3, 10): assert_type(await path.hardlink_to("elsewhere"), None) assert_type(await path.touch(), None) assert_type(await path.unlink(missing_ok=True), None) assert_type(await path.write_bytes(b"123"), int) assert_type( - await path.write_text("hello", encoding="utf32le", errors="ignore"), int + await path.write_text("hello", encoding="utf32le", errors="ignore"), + int, ) @@ -140,4 +138,4 @@ async def open_results(path: trio.Path, some_int: int, some_str: str) -> None: assert_type(await file_text.read(), str) assert_type(await file_text.write("test"), int) # TODO: report mypy bug: equiv to https://github.com/microsoft/pyright/issues/6833 - assert_type(await file_text.readlines(), List[str]) + assert_type(await file_text.readlines(), list[str]) diff --git a/src/trio/_tests/type_tests/raisesgroup.py b/src/trio/_tests/type_tests/raisesgroup.py index fe4053ebc5..4d5ed4882c 100644 --- a/src/trio/_tests/type_tests/raisesgroup.py +++ b/src/trio/_tests/type_tests/raisesgroup.py @@ -1,21 +1,7 @@ -"""The typing of RaisesGroup involves a lot of deception and lies, since AFAIK what we -actually want to achieve is ~impossible. This is because we specify what we expect with -instances of RaisesGroup and exception classes, but excinfo.value will be instances of -[Base]ExceptionGroup and instances of exceptions. So we need to "translate" from -RaisesGroup to ExceptionGroup. - -The way it currently works is that RaisesGroup[E] corresponds to -ExceptionInfo[BaseExceptionGroup[E]], so the top-level group will be correct. But -RaisesGroup[RaisesGroup[ValueError]] will become -ExceptionInfo[BaseExceptionGroup[RaisesGroup[ValueError]]]. To get around that we specify -RaisesGroup as a subclass of BaseExceptionGroup during type checking - which should mean -that most static type checking for end users should be mostly correct. -""" - from __future__ import annotations import sys -from typing import Union +from typing import Callable, Union from trio.testing import Matcher, RaisesGroup from typing_extensions import assert_type @@ -26,17 +12,6 @@ # split into functions to isolate the different scopes -def check_inheritance_and_assignments() -> None: - # Check inheritance - _: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError) - _ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore - - a: BaseExceptionGroup[BaseExceptionGroup[ValueError]] - a = RaisesGroup(RaisesGroup(ValueError)) - a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),)) - assert a - - def check_matcher_typevar_default(e: Matcher) -> object: assert e.exception_type is not None exc: type[BaseException] = e.exception_type @@ -46,28 +21,32 @@ def check_matcher_typevar_default(e: Matcher) -> object: def check_basic_contextmanager() -> None: - # One level of Group is correctly translated - except it's a BaseExceptionGroup - # instead of an ExceptionGroup. with RaisesGroup(ValueError) as e: raise ExceptionGroup("foo", (ValueError(),)) - assert_type(e.value, BaseExceptionGroup[ValueError]) + assert_type(e.value, ExceptionGroup[ValueError]) def check_basic_matches() -> None: # check that matches gets rid of the naked ValueError in the union exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup("", (ValueError(),)) if RaisesGroup(ValueError).matches(exc): - assert_type(exc, BaseExceptionGroup[ValueError]) + assert_type(exc, ExceptionGroup[ValueError]) + + # also check that BaseExceptionGroup shows up for BaseExceptions + if RaisesGroup(KeyboardInterrupt).matches(exc): + assert_type(exc, BaseExceptionGroup[KeyboardInterrupt]) def check_matches_with_different_exception_type() -> None: - # This should probably raise some type error somewhere, since - # ValueError != KeyboardInterrupt e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( - "", (KeyboardInterrupt(),) + "", + (KeyboardInterrupt(),), ) + + # note: it might be tempting to have this warn. + # however, that isn't possible with current typing if RaisesGroup(ValueError).matches(e): - assert_type(e, BaseExceptionGroup[ValueError]) + assert_type(e, ExceptionGroup[ValueError]) def check_matcher_init() -> None: @@ -133,16 +112,7 @@ def handle_value(e: BaseExceptionGroup[ValueError]) -> bool: def raisesgroup_narrow_baseexceptiongroup() -> None: - """Check type narrowing specifically for the container exceptiongroup. - This is not currently working, and after playing around with it for a bit - I think the only way is to introduce a subclass `NonBaseRaisesGroup`, and overload - `__new__` in Raisesgroup to return the subclass when exceptions are non-base. - (or make current class BaseRaisesGroup and introduce RaisesGroup for non-base) - I encountered problems trying to type this though, see - https://github.com/python/mypy/issues/17251 - That is probably possible to work around by entirely using `__new__` instead of - `__init__`, but........ ugh. - """ + """Check type narrowing specifically for the container exceptiongroup.""" def handle_group(e: ExceptionGroup[Exception]) -> bool: return True @@ -150,53 +120,48 @@ def handle_group(e: ExceptionGroup[Exception]) -> bool: def handle_group_value(e: ExceptionGroup[ValueError]) -> bool: return True - # should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup - RaisesGroup(ValueError, check=handle_group_value) # type: ignore + RaisesGroup(ValueError, check=handle_group_value) - # should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup - RaisesGroup(Exception, check=handle_group) # type: ignore + RaisesGroup(Exception, check=handle_group) def check_matcher_transparent() -> None: with RaisesGroup(Matcher(ValueError)) as e: ... _: BaseExceptionGroup[ValueError] = e.value - assert_type(e.value, BaseExceptionGroup[ValueError]) + assert_type(e.value, ExceptionGroup[ValueError]) def check_nested_raisesgroups_contextmanager() -> None: with RaisesGroup(RaisesGroup(ValueError)) as excinfo: raise ExceptionGroup("foo", (ValueError(),)) - # thanks to inheritance this assignment works _: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value - # and it can mostly be treated like an exceptiongroup - print(excinfo.value.exceptions[0].exceptions[0]) - # but assert_type reveals the lies - print(type(excinfo.value)) # would print "ExceptionGroup" - # typing says it's a BaseExceptionGroup assert_type( excinfo.value, - BaseExceptionGroup[RaisesGroup[ValueError]], + ExceptionGroup[ExceptionGroup[ValueError]], ) - print(type(excinfo.value.exceptions[0])) # would print "ExceptionGroup" - # but type checkers are utterly confused assert_type( excinfo.value.exceptions[0], - Union[RaisesGroup[ValueError], BaseExceptionGroup[RaisesGroup[ValueError]]], + # this union is because of how typeshed defines .exceptions + Union[ + ExceptionGroup[ValueError], + ExceptionGroup[ExceptionGroup[ValueError]], + ], ) def check_nested_raisesgroups_matches() -> None: """Check nested RaisesGroups with .matches""" exc: ExceptionGroup[ExceptionGroup[ValueError]] = ExceptionGroup( - "", (ExceptionGroup("", (ValueError(),)),) + "", + (ExceptionGroup("", (ValueError(),)),), ) - # has the same problems as check_nested_raisesgroups_contextmanager + if RaisesGroup(RaisesGroup(ValueError)).matches(exc): - assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]]) + assert_type(exc, ExceptionGroup[ExceptionGroup[ValueError]]) def check_multiple_exceptions_1() -> None: @@ -204,7 +169,7 @@ def check_multiple_exceptions_1() -> None: b = RaisesGroup(Matcher(ValueError), Matcher(ValueError)) c = RaisesGroup(ValueError, Matcher(ValueError)) - d: BaseExceptionGroup[ValueError] + d: RaisesGroup[ValueError] d = a d = b d = c @@ -217,7 +182,7 @@ def check_multiple_exceptions_2() -> None: b = RaisesGroup(Matcher(ValueError), TypeError) c = RaisesGroup(ValueError, TypeError) - d: BaseExceptionGroup[Exception] + d: RaisesGroup[Exception] d = a d = b d = c @@ -250,3 +215,25 @@ def check_raisesgroup_overloads() -> None: # if they're both false we can of course specify nested raisesgroup RaisesGroup(RaisesGroup(ValueError)) + + +def check_triple_nested_raisesgroup() -> None: + with RaisesGroup(RaisesGroup(RaisesGroup(ValueError))) as e: + assert_type(e.value, ExceptionGroup[ExceptionGroup[ExceptionGroup[ValueError]]]) + + +def check_check_typing() -> None: + # mypy issue is https://github.com/python/mypy/issues/18185 + + # fmt: off + # mypy raises an error on `assert_type` + # pyright raises an error on `RaisesGroup(ValueError).check` + # to satisfy both, need to disable formatting and put it on one line + assert_type(RaisesGroup(ValueError).check, # type: ignore + Union[ + Callable[[BaseExceptionGroup[ValueError]], None], + Callable[[ExceptionGroup[ValueError]], None], + None, + ], + ) + # fmt: on diff --git a/src/trio/_tests/type_tests/task_status.py b/src/trio/_tests/type_tests/task_status.py index 90cfc6957f..6d2e3922bb 100644 --- a/src/trio/_tests/type_tests/task_status.py +++ b/src/trio/_tests/type_tests/task_status.py @@ -4,7 +4,7 @@ from typing_extensions import assert_type -async def check_status( +def check_status( none_status_explicit: TaskStatus[None], none_status_implicit: TaskStatus, int_status: TaskStatus[int], diff --git a/src/trio/_threads.py b/src/trio/_threads.py index a04b737292..7afd7b612a 100644 --- a/src/trio/_threads.py +++ b/src/trio/_threads.py @@ -139,21 +139,24 @@ def current_default_thread_limiter() -> CapacityLimiter: # system; see https://github.com/python-trio/trio/issues/182 # But for now we just need an object to stand in for the thread, so we can # keep track of who's holding the CapacityLimiter's token. -@attrs.frozen(eq=False, hash=False, slots=False) +@attrs.frozen(eq=False, slots=False) class ThreadPlaceholder: name: str # Types for the to_thread_run_sync message loop @attrs.frozen(eq=False, slots=False) -class Run(Generic[RetT]): - afn: Callable[..., Awaitable[RetT]] +# Explicit .../"Any" is not allowed +class Run(Generic[RetT]): # type: ignore[misc] + afn: Callable[..., Awaitable[RetT]] # type: ignore[misc] args: tuple[object, ...] context: contextvars.Context = attrs.field( - init=False, factory=contextvars.copy_context + init=False, + factory=contextvars.copy_context, ) queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attrs.field( - init=False, factory=stdlib_queue.SimpleQueue + init=False, + factory=stdlib_queue.SimpleQueue, ) @disable_ki_protection @@ -190,25 +193,30 @@ def run_in_system_nursery(self, token: TrioToken) -> None: def in_trio_thread() -> None: try: trio.lowlevel.spawn_system_task( - self.run_system, name=self.afn, context=self.context + self.run_system, + name=self.afn, + context=self.context, ) except RuntimeError: # system nursery is closed self.queue.put_nowait( - outcome.Error(trio.RunFinishedError("system nursery is closed")) + outcome.Error(trio.RunFinishedError("system nursery is closed")), ) token.run_sync_soon(in_trio_thread) @attrs.frozen(eq=False, slots=False) -class RunSync(Generic[RetT]): - fn: Callable[..., RetT] +# Explicit .../"Any" is not allowed +class RunSync(Generic[RetT]): # type: ignore[misc] + fn: Callable[..., RetT] # type: ignore[misc] args: tuple[object, ...] context: contextvars.Context = attrs.field( - init=False, factory=contextvars.copy_context + init=False, + factory=contextvars.copy_context, ) queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attrs.field( - init=False, factory=stdlib_queue.SimpleQueue + init=False, + factory=stdlib_queue.SimpleQueue, ) @disable_ki_protection @@ -220,7 +228,7 @@ def unprotected_fn(self) -> RetT: ret.close() raise TypeError( "Trio expected a synchronous function, but {!r} appears to be " - "asynchronous".format(getattr(self.fn, "__qualname__", self.fn)) + "asynchronous".format(getattr(self.fn, "__qualname__", self.fn)), ) return ret @@ -386,7 +394,7 @@ def worker_fn() -> RetT: ret.close() raise TypeError( "Trio expected a sync function, but {!r} appears to be " - "asynchronous".format(getattr(sync_fn, "__qualname__", sync_fn)) + "asynchronous".format(getattr(sync_fn, "__qualname__", sync_fn)), ) return ret @@ -441,7 +449,7 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: msg_from_thread.run_sync() else: # pragma: no cover, internal debugging guard TODO: use assert_never raise TypeError( - f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}." + f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.", ) del msg_from_thread @@ -477,14 +485,15 @@ def from_thread_check_cancelled() -> None: raise_cancel = PARENT_TASK_DATA.cancel_register[0] except AttributeError: raise RuntimeError( - "this thread wasn't created by Trio, can't check for cancellation" + "this thread wasn't created by Trio, can't check for cancellation", ) from None if raise_cancel is not None: raise_cancel() def _send_message_to_trio( - trio_token: TrioToken | None, message_to_trio: Run[RetT] | RunSync[RetT] + trio_token: TrioToken | None, + message_to_trio: Run[RetT] | RunSync[RetT], ) -> RetT: """Shared logic of from_thread functions""" token_provided = trio_token is not None @@ -494,7 +503,7 @@ def _send_message_to_trio( trio_token = PARENT_TASK_DATA.token except AttributeError: raise RuntimeError( - "this thread wasn't created by Trio, pass kwarg trio_token=..." + "this thread wasn't created by Trio, pass kwarg trio_token=...", ) from None elif not isinstance(trio_token, TrioToken): raise RuntimeError("Passed kwarg trio_token is not of type TrioToken") @@ -515,7 +524,8 @@ def _send_message_to_trio( return message_to_trio.queue.get().unwrap() -def from_thread_run( +# Explicit "Any" is not allowed +def from_thread_run( # type: ignore[misc] afn: Callable[..., Awaitable[RetT]], *args: object, trio_token: TrioToken | None = None, @@ -559,7 +569,8 @@ def from_thread_run( return _send_message_to_trio(trio_token, Run(afn, args)) -def from_thread_run_sync( +# Explicit "Any" is not allowed +def from_thread_run_sync( # type: ignore[misc] fn: Callable[..., RetT], *args: object, trio_token: TrioToken | None = None, diff --git a/src/trio/_timeouts.py b/src/trio/_timeouts.py index 1d03b2f2e3..7ce123c7c5 100644 --- a/src/trio/_timeouts.py +++ b/src/trio/_timeouts.py @@ -1,51 +1,75 @@ from __future__ import annotations import math -from contextlib import AbstractContextManager, contextmanager -from typing import TYPE_CHECKING +import sys +from contextlib import contextmanager +from typing import TYPE_CHECKING, NoReturn import trio +if TYPE_CHECKING: + from collections.abc import Generator -def move_on_at(deadline: float) -> trio.CancelScope: + +def move_on_at(deadline: float, *, shield: bool = False) -> trio.CancelScope: """Use as a context manager to create a cancel scope with the given absolute deadline. Args: deadline (float): The deadline. + shield (bool): Initial value for the `~trio.CancelScope.shield` attribute + of the newly created cancel scope. Raises: ValueError: if deadline is NaN. """ - if math.isnan(deadline): - raise ValueError("deadline must not be NaN") - return trio.CancelScope(deadline=deadline) + # CancelScope validates that deadline isn't math.nan + return trio.CancelScope(deadline=deadline, shield=shield) -def move_on_after(seconds: float) -> trio.CancelScope: +def move_on_after( + seconds: float, + *, + shield: bool = False, +) -> trio.CancelScope: """Use as a context manager to create a cancel scope whose deadline is set to now + *seconds*. + The deadline of the cancel scope is calculated upon entering. + Args: seconds (float): The timeout. + shield (bool): Initial value for the `~trio.CancelScope.shield` attribute + of the newly created cancel scope. Raises: - ValueError: if timeout is less than zero or NaN. + ValueError: if ``seconds`` is less than zero or NaN. """ + # duplicate validation logic to have the correct parameter name if seconds < 0: - raise ValueError("timeout must be non-negative") - return move_on_at(trio.current_time() + seconds) + raise ValueError("`seconds` must be non-negative") + if math.isnan(seconds): + raise ValueError("`seconds` must not be NaN") + return trio.CancelScope( + shield=shield, + relative_deadline=seconds, + ) -async def sleep_forever() -> None: +async def sleep_forever() -> NoReturn: """Pause execution of the current task forever (or until cancelled). - Equivalent to calling ``await sleep(math.inf)``. + Equivalent to calling ``await sleep(math.inf)``, except that if manually + rescheduled this will raise a `RuntimeError`. + + Raises: + RuntimeError: if rescheduled """ await trio.lowlevel.wait_task_rescheduled(lambda _: trio.lowlevel.Abort.SUCCEEDED) + raise RuntimeError("Should never have been rescheduled!") async def sleep_until(deadline: float) -> None: @@ -80,7 +104,7 @@ async def sleep(seconds: float) -> None: """ if seconds < 0: - raise ValueError("duration must be non-negative") + raise ValueError("`seconds` must be non-negative") if seconds == 0: await trio.lowlevel.checkpoint() else: @@ -94,9 +118,12 @@ class TooSlowError(Exception): """ -# workaround for PyCharm not being able to infer return type from @contextmanager -# see https://youtrack.jetbrains.com/issue/PY-36444/PyCharm-doesnt-infer-types-when-using-contextlib.contextmanager-decorator -def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # type: ignore[misc] +@contextmanager +def fail_at( + deadline: float, + *, + shield: bool = False, +) -> Generator[trio.CancelScope, None, None]: """Creates a cancel scope with the given deadline, and raises an error if it is actually cancelled. @@ -110,6 +137,8 @@ def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # typ Args: deadline (float): The deadline. + shield (bool): Initial value for the `~trio.CancelScope.shield` attribute + of the newly created cancel scope. Raises: TooSlowError: if a :exc:`Cancelled` exception is raised in this scope @@ -117,17 +146,18 @@ def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # typ ValueError: if deadline is NaN. """ - with move_on_at(deadline) as scope: + with move_on_at(deadline, shield=shield) as scope: yield scope if scope.cancelled_caught: raise TooSlowError -if not TYPE_CHECKING: - fail_at = contextmanager(fail_at) - - -def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]: +@contextmanager +def fail_after( + seconds: float, + *, + shield: bool = False, +) -> Generator[trio.CancelScope, None, None]: """Creates a cancel scope with the given timeout, and raises an error if it is actually cancelled. @@ -138,8 +168,12 @@ def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]: it's caught and discarded. When it reaches :func:`fail_after`, then it's caught and :exc:`TooSlowError` is raised in its place. + The deadline of the cancel scope is calculated upon entering. + Args: seconds (float): The timeout. + shield (bool): Initial value for the `~trio.CancelScope.shield` attribute + of the newly created cancel scope. Raises: TooSlowError: if a :exc:`Cancelled` exception is raised in this scope @@ -147,6 +181,17 @@ def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]: ValueError: if *seconds* is less than zero or NaN. """ - if seconds < 0: - raise ValueError("timeout must be non-negative") - return fail_at(trio.current_time() + seconds) + with move_on_after(seconds, shield=shield) as scope: + yield scope + if scope.cancelled_caught: + raise TooSlowError + + +# Users don't need to know that fail_at & fail_after wraps move_on_at and move_on_after +# and there is no functional difference. So we replace the return value when generating +# documentation. +if "sphinx" in sys.modules: # pragma: no cover + import inspect + + for c in (fail_at, fail_after): + c.__signature__ = inspect.Signature.from_callable(c).replace(return_annotation=trio.CancelScope) # type: ignore[union-attr] diff --git a/src/trio/_tools/gen_exports.py b/src/trio/_tools/gen_exports.py index 51cb69f2a9..ae6b0293e8 100755 --- a/src/trio/_tools/gen_exports.py +++ b/src/trio/_tools/gen_exports.py @@ -34,12 +34,11 @@ import sys -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT """ -TEMPLATE = """sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True -try: +TEMPLATE = """try: return{}GLOBAL_RUN_CONTEXT.{}.{} except AttributeError: raise RuntimeError("must be called from async context") from None @@ -181,17 +180,13 @@ def run_linters(file: File, source: str) -> str: SystemExit: If either failed. """ - success, response = run_black(file, source) - if not success: - print(response) - sys.exit(1) - - success, response = run_ruff(file, response) - if not success: # pragma: no cover # Test for run_ruff should catch - print(response) - sys.exit(1) + for fn in (run_black, run_ruff): + success, source = fn(file, source) + if not success: + print(source) + sys.exit(1) - return response + return source def gen_public_wrappers_source(file: File) -> str: @@ -200,9 +195,7 @@ def gen_public_wrappers_source(file: File) -> str: """ header = [HEADER] - - if file.imports: - header.append(file.imports) + header.append(file.imports) if file.platform: # Simple checks to avoid repeating imports. If this messes up, type checkers/tests will # just give errors. @@ -211,7 +204,7 @@ def gen_public_wrappers_source(file: File) -> str: if "import sys" not in file.imports: # pragma: no cover header.append("import sys\n") header.append( - f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n' + f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n', ) generated = ["".join(header)] @@ -232,7 +225,7 @@ def gen_public_wrappers_source(file: File) -> str: is_cm = False # Remove decorators - method.decorator_list = [] + method.decorator_list = [ast.Name("enable_ki_protection")] # Create pass through arguments new_args = create_passthrough_args(method) @@ -248,7 +241,7 @@ def gen_public_wrappers_source(file: File) -> str: func = astor.to_source(method, indent_with=" " * 4) if is_cm: # pragma: no cover - func = func.replace("->Iterator", "->ContextManager") + func = func.replace("->Iterator", "->AbstractContextManager") # Create export function body template = TEMPLATE.format( @@ -273,8 +266,7 @@ def matches_disk_files(new_files: dict[str, str]) -> bool: for new_path, new_source in new_files.items(): if not os.path.exists(new_path): return False - with open(new_path, encoding="utf-8") as old_file: - old_source = old_file.read() + old_source = Path(new_path).read_text(encoding="utf-8") if old_source != new_source: return False return True @@ -289,27 +281,34 @@ def process(files: Iterable[File], *, do_test: bool) -> None: dirname, basename = os.path.split(file.path) new_path = os.path.join(dirname, PREFIX + basename) new_files[new_path] = new_source + matches_disk = matches_disk_files(new_files) if do_test: - if not matches_disk_files(new_files): + if not matches_disk: print("Generated sources are outdated. Please regenerate.") sys.exit(1) else: print("Generated sources are up to date.") else: for new_path, new_source in new_files.items(): - with open(new_path, "w", encoding="utf-8", newline="\n") as f: - f.write(new_source) + with open(new_path, "w", encoding="utf-8", newline="\n") as fp: + fp.write(new_source) print("Regenerated sources successfully.") + if not matches_disk: # TODO: test this branch + # With pre-commit integration, show that we edited files. + sys.exit(1) # This is in fact run in CI, but only in the formatting check job, which # doesn't collect coverage. def main() -> None: # pragma: no cover parser = argparse.ArgumentParser( - description="Generate python code for public api wrappers" + description="Generate python code for public api wrappers", ) parser.add_argument( - "--test", "-t", action="store_true", help="test if code is still up to date" + "--test", + "-t", + action="store_true", + help="test if code is still up to date", ) parsed_args = parser.parse_args() @@ -374,25 +373,30 @@ def main() -> None: # pragma: no cover """ IMPORTS_KQUEUE = """\ -from typing import Callable, ContextManager, TYPE_CHECKING +from typing import TYPE_CHECKING if TYPE_CHECKING: import select + from collections.abc import Callable + from contextlib import AbstractContextManager from .._channel import MemoryReceiveChannel + from .. import _core from .._file_io import _HasFileNo from ._traps import Abort, RaiseCancelT """ IMPORTS_WINDOWS = """\ -from typing import TYPE_CHECKING, ContextManager +from typing import TYPE_CHECKING if TYPE_CHECKING: - from .._file_io import _HasFileNo - from ._windows_cffi import Handle, CData + from contextlib import AbstractContextManager + from typing_extensions import Buffer from .._channel import MemoryReceiveChannel + from .._file_io import _HasFileNo + from ._windows_cffi import Handle, CData """ diff --git a/src/trio/_tools/mypy_annotate.py b/src/trio/_tools/mypy_annotate.py index 6bd20f401c..5acb9b993c 100644 --- a/src/trio/_tools/mypy_annotate.py +++ b/src/trio/_tools/mypy_annotate.py @@ -86,7 +86,9 @@ def main(argv: list[str]) -> None: """Look for error messages, and convert the format.""" parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--dumpfile", help="File to write pickled messages to.", required=True + "--dumpfile", + help="File to write pickled messages to.", + required=True, ) parser.add_argument( "--platform", diff --git a/src/trio/_unix_pipes.py b/src/trio/_unix_pipes.py index 34340d2b36..a95f761bcc 100644 --- a/src/trio/_unix_pipes.py +++ b/src/trio/_unix_pipes.py @@ -121,10 +121,10 @@ class FdStream(Stream): def __init__(self, fd: int) -> None: self._fd_holder = _FdHolder(fd) self._send_conflict_detector = ConflictDetector( - "another task is using this stream for send" + "another task is using this stream for send", ) self._receive_conflict_detector = ConflictDetector( - "another task is using this stream for receive" + "another task is using this stream for receive", ) async def send_all(self, data: bytes) -> None: @@ -147,7 +147,7 @@ async def send_all(self, data: bytes) -> None: except OSError as e: if e.errno == errno.EBADF: raise trio.ClosedResourceError( - "file was already closed" + "file was already closed", ) from None else: raise trio.BrokenResourceError from e @@ -182,7 +182,7 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes: except OSError as exc: if exc.errno == errno.EBADF: raise trio.ClosedResourceError( - "file was already closed" + "file was already closed", ) from None else: raise trio.BrokenResourceError from exc diff --git a/src/trio/_util.py b/src/trio/_util.py index 7c9e194d19..994a4655b2 100644 --- a/src/trio/_util.py +++ b/src/trio/_util.py @@ -3,19 +3,15 @@ import collections.abc import inspect -import os import signal -import threading from abc import ABCMeta +from collections.abc import Awaitable, Callable, Sequence from functools import update_wrapper from typing import ( TYPE_CHECKING, Any, - Awaitable, - Callable, Generic, NoReturn, - Sequence, TypeVar, final as std_final, ) @@ -24,7 +20,8 @@ import trio -CallT = TypeVar("CallT", bound=Callable[..., Any]) +# Explicit "Any" is not allowed +CallT = TypeVar("CallT", bound=Callable[..., Any]) # type: ignore[misc] T = TypeVar("T") RetT = TypeVar("RetT") @@ -37,60 +34,6 @@ PosArgsT = TypeVarTuple("PosArgsT") -if TYPE_CHECKING: - # Don't type check the implementation below, pthread_kill does not exist on Windows. - def signal_raise(signum: int) -> None: ... - - -# Equivalent to the C function raise(), which Python doesn't wrap -elif os.name == "nt": - # On Windows, os.kill exists but is really weird. - # - # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver - # those using GenerateConsoleCtrlEvent. But I found that when I tried - # to run my test normally, it would freeze waiting... unless I added - # print statements, in which case the test suddenly worked. So I guess - # these signals are only delivered if/when you access the console? I - # don't really know what was going on there. From reading the - # GenerateConsoleCtrlEvent docs I don't know how it worked at all. - # - # I later spent a bunch of time trying to make GenerateConsoleCtrlEvent - # work for creating synthetic control-C events, and... failed - # utterly. There are lots of details in the code and comments - # removed/added at this commit: - # https://github.com/python-trio/trio/commit/95843654173e3e826c34d70a90b369ba6edf2c23 - # - # OTOH, if you pass os.kill any *other* signal number... then CPython - # just calls TerminateProcess (wtf). - # - # So, anyway, os.kill is not so useful for testing purposes. Instead, - # we use raise(): - # - # https://msdn.microsoft.com/en-us/library/dwwzkt4c.aspx - # - # Have to import cffi inside the 'if os.name' block because we don't - # depend on cffi on non-Windows platforms. (It would be easy to switch - # this to ctypes though if we ever remove the cffi dependency.) - # - # Some more information: - # https://bugs.python.org/issue26350 - # - # Anyway, we use this for two things: - # - redelivering unhandled signals - # - generating synthetic signals for tests - # and for both of those purposes, 'raise' works fine. - import cffi - - _ffi = cffi.FFI() - _ffi.cdef("int raise(int);") - _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll") - signal_raise = getattr(_lib, "raise") -else: - - def signal_raise(signum: int) -> None: - signal.pthread_kill(threading.get_ident(), signum) - - # See: #461 as to why this is needed. # The gist is that threading.main_thread() has the capability to lie to us # if somebody else edits the threading ident cache to replace the main @@ -154,7 +97,7 @@ def _return_value_looks_like_wrong_library(value: object) -> bool: "Instead, you want (notice the parentheses!):\n" "\n" f" trio.run({async_fn.__name__}, ...) # correct!\n" - f" nursery.start_soon({async_fn.__name__}, ...) # correct!" + f" nursery.start_soon({async_fn.__name__}, ...) # correct!", ) from None # Give good error for: nursery.start_soon(future) @@ -163,7 +106,7 @@ def _return_value_looks_like_wrong_library(value: object) -> bool: "Trio was expecting an async function, but instead it got " f"{async_fn!r} โ€“ are you trying to use a library written for " "asyncio/twisted/tornado or similar? That won't work " - "without some sort of compatibility shim." + "without some sort of compatibility shim.", ) from None raise @@ -183,19 +126,19 @@ def _return_value_looks_like_wrong_library(value: object) -> bool: raise TypeError( f"Trio got unexpected {coro!r} โ€“ are you trying to use a " "library written for asyncio/twisted/tornado or similar? " - "That won't work without some sort of compatibility shim." + "That won't work without some sort of compatibility shim.", ) if inspect.isasyncgen(coro): raise TypeError( "start_soon expected an async function but got an async " - f"generator {coro!r}" + f"generator {coro!r}", ) # Give good error for: nursery.start_soon(some_sync_fn) raise TypeError( "Trio expected an async function, but {!r} appears to be " - "synchronous".format(getattr(async_fn, "__qualname__", async_fn)) + "synchronous".format(getattr(async_fn, "__qualname__", async_fn)), ) return coro @@ -234,16 +177,18 @@ def __exit__( self._held = False -def async_wraps( +# Explicit "Any" is not allowed +def async_wraps( # type: ignore[misc] cls: type[object], wrapped_cls: type[object], attr_name: str, ) -> Callable[[CallT], CallT]: """Similar to wraps, but for async wrappers of non-async functions.""" - def decorator(func: CallT) -> CallT: + # Explicit "Any" is not allowed + def decorator(func: CallT) -> CallT: # type: ignore[misc] func.__name__ = attr_name - func.__qualname__ = ".".join((cls.__qualname__, attr_name)) + func.__qualname__ = f"{cls.__qualname__}.{attr_name}" func.__doc__ = f"Like :meth:`~{wrapped_cls.__module__}.{wrapped_cls.__qualname__}.{attr_name}`, but async." @@ -253,7 +198,8 @@ def decorator(func: CallT) -> CallT: def fixup_module_metadata( - module_name: str, namespace: collections.abc.Mapping[str, object] + module_name: str, + namespace: collections.abc.Mapping[str, object], ) -> None: seen_ids: set[int] = set() @@ -303,11 +249,15 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ but at least it becomes possible to write those. """ - def __init__(self, fn: Callable[..., RetT]) -> None: + # Explicit .../"Any" is not allowed + def __init__( # type: ignore[misc] + self, + fn: Callable[..., RetT], + ) -> None: update_wrapper(self, fn) self._fn = fn - def __call__(self, *args: Any, **kwargs: Any) -> RetT: + def __call__(self, *args: object, **kwargs: object) -> RetT: return self._fn(*args, **kwargs) def __getitem__(self, subscript: object) -> Self: @@ -370,7 +320,7 @@ class SomeClass(metaclass=NoPublicConstructor): def __call__(cls, *args: object, **kwargs: object) -> None: raise TypeError( - f"{cls.__module__}.{cls.__qualname__} has no public constructor" + f"{cls.__module__}.{cls.__qualname__} has no public constructor", ) def _create(cls: type[T], *args: object, **kwargs: object) -> T: @@ -396,9 +346,11 @@ def name_asyncgen(agen: AsyncGeneratorType[object, NoReturn]) -> str: # work around a pyright error if TYPE_CHECKING: - Fn = TypeVar("Fn", bound=Callable[..., object]) + # Explicit .../"Any" is not allowed + Fn = TypeVar("Fn", bound=Callable[..., object]) # type: ignore[misc] - def wraps( + # Explicit .../"Any" is not allowed + def wraps( # type: ignore[misc] wrapped: Callable[..., object], assigned: Sequence[str] = ..., updated: Sequence[str] = ..., diff --git a/src/trio/_version.py b/src/trio/_version.py index b777fa4efe..0ff89b6f87 100644 --- a/src/trio/_version.py +++ b/src/trio/_version.py @@ -1,3 +1,3 @@ # This file is imported from __init__.py and parsed by setuptools -__version__ = "0.26.0+dev" +__version__ = "0.27.0+dev" diff --git a/src/trio/_windows_pipes.py b/src/trio/_windows_pipes.py index 43592807b8..e1eea1e72d 100644 --- a/src/trio/_windows_pipes.py +++ b/src/trio/_windows_pipes.py @@ -49,7 +49,7 @@ class PipeSendStream(SendStream): def __init__(self, handle: int) -> None: self._handle_holder = _HandleHolder(handle) self._conflict_detector = ConflictDetector( - "another task is currently using this pipe" + "another task is currently using this pipe", ) async def send_all(self, data: bytes) -> None: @@ -93,7 +93,7 @@ class PipeReceiveStream(ReceiveStream): def __init__(self, handle: int) -> None: self._handle_holder = _HandleHolder(handle) self._conflict_detector = ConflictDetector( - "another task is currently using this pipe" + "another task is currently using this pipe", ) async def receive_some(self, max_bytes: int | None = None) -> bytes: @@ -112,12 +112,13 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes: buffer = bytearray(max_bytes) try: size = await _core.readinto_overlapped( - self._handle_holder.handle, buffer + self._handle_holder.handle, + buffer, ) except BrokenPipeError: if self._handle_holder.closed: raise _core.ClosedResourceError( - "another task closed this pipe" + "another task closed this pipe", ) from None # Windows raises BrokenPipeError on one end of a pipe diff --git a/src/trio/lowlevel.py b/src/trio/lowlevel.py index 711e0d1dca..2a869652ad 100644 --- a/src/trio/lowlevel.py +++ b/src/trio/lowlevel.py @@ -23,6 +23,7 @@ Task as Task, TrioToken as TrioToken, add_instrument as add_instrument, + add_parking_lot_breaker as add_parking_lot_breaker, cancel_shielded_checkpoint as cancel_shielded_checkpoint, checkpoint as checkpoint, checkpoint_if_cancelled as checkpoint_if_cancelled, @@ -38,6 +39,7 @@ permanently_detach_coroutine_object as permanently_detach_coroutine_object, reattach_detached_coroutine_object as reattach_detached_coroutine_object, remove_instrument as remove_instrument, + remove_parking_lot_breaker as remove_parking_lot_breaker, reschedule as reschedule, spawn_system_task as spawn_system_task, start_guest_run as start_guest_run, diff --git a/src/trio/socket.py b/src/trio/socket.py index e38501fb60..617f0382c0 100644 --- a/src/trio/socket.py +++ b/src/trio/socket.py @@ -29,7 +29,7 @@ _name: getattr(_stdlib_socket, _name) for _name in _stdlib_socket.__all__ # type: ignore if _name.isupper() and _name not in _bad_symbols - } + }, ) # import the overwrites diff --git a/src/trio/testing/_check_streams.py b/src/trio/testing/_check_streams.py index c54c99c1fe..e58e2ddfed 100644 --- a/src/trio/testing/_check_streams.py +++ b/src/trio/testing/_check_streams.py @@ -3,14 +3,11 @@ import random import sys +from collections.abc import Awaitable, Callable, Generator from contextlib import contextmanager, suppress from typing import ( TYPE_CHECKING, - Awaitable, - Callable, - Generator, Generic, - Tuple, TypeVar, ) @@ -31,7 +28,7 @@ Res1 = TypeVar("Res1", bound=AsyncResource) Res2 = TypeVar("Res2", bound=AsyncResource) -StreamMaker: TypeAlias = Callable[[], Awaitable[Tuple[Res1, Res2]]] +StreamMaker: TypeAlias = Callable[[], Awaitable[tuple[Res1, Res2]]] class _ForceCloseBoth(Generic[Res1, Res2]): @@ -57,7 +54,8 @@ async def __aexit__( # on pytest, as the check_* functions are publicly exported. @contextmanager def _assert_raises( - expected_exc: type[BaseException], wrapped: bool = False + expected_exc: type[BaseException], + wrapped: bool = False, ) -> Generator[None, None, None]: __tracebackhide__ = True try: @@ -174,7 +172,8 @@ async def simple_check_wait_send_all_might_not_block( async with _core.open_nursery() as nursery: nursery.start_soon( - simple_check_wait_send_all_might_not_block, nursery.cancel_scope + simple_check_wait_send_all_might_not_block, + nursery.cancel_scope, ) nursery.start_soon(do_receive_some, 1) @@ -312,7 +311,7 @@ async def expect_cancelled( # receive stream causes it to wake up. async with _ForceCloseBoth(await stream_maker()) as (s, r): - async def receive_expecting_closed(): + async def receive_expecting_closed() -> None: with _assert_raises(_core.ClosedResourceError): await r.receive_some(10) @@ -467,7 +466,9 @@ async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]: test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little") async def sender( - s: Stream, data: bytes | bytearray | memoryview, seed: int + s: Stream, + data: bytes | bytearray | memoryview, + seed: int, ) -> None: r = random.Random(seed) m = memoryview(data) diff --git a/src/trio/testing/_fake_net.py b/src/trio/testing/_fake_net.py index f8589f3a9c..2f5bd624ae 100644 --- a/src/trio/testing/_fake_net.py +++ b/src/trio/testing/_fake_net.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, Any, - Iterable, NoReturn, TypeVar, Union, @@ -31,11 +30,14 @@ if TYPE_CHECKING: import builtins + from collections.abc import Iterable from socket import AddressFamily, SocketKind from types import TracebackType from typing_extensions import Buffer, Self, TypeAlias + from trio._socket import AddressFormat + IPAddress: TypeAlias = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -99,7 +101,8 @@ def as_python_sockaddr(self) -> tuple[str, int] | tuple[str, int, int, int]: @classmethod def from_python_sockaddr( - cls: type[T_UDPEndpoint], sockaddr: tuple[str, int] | tuple[str, int, int, int] + cls: type[T_UDPEndpoint], + sockaddr: tuple[str, int] | tuple[str, int, int, int], ) -> T_UDPEndpoint: ip, port = sockaddr[:2] return cls(ip=ipaddress.ip_address(ip), port=port) @@ -120,7 +123,9 @@ class UDPPacket: # not used/tested anywhere def reply(self, payload: bytes) -> UDPPacket: # pragma: no cover return UDPPacket( - source=self.destination, destination=self.source, payload=payload + source=self.destination, + destination=self.source, + payload=payload, ) @@ -156,7 +161,9 @@ async def getaddrinfo( raise NotImplementedError("FakeNet doesn't do fake DNS yet") async def getnameinfo( - self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + self, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, ) -> tuple[str, str]: raise NotImplementedError("FakeNet doesn't do fake DNS yet") @@ -205,7 +212,7 @@ def __init__( family: AddressFamily, type: SocketKind, proto: int, - ): + ) -> None: self._fake_net = fake_net if not family: # pragma: no cover @@ -256,7 +263,10 @@ def close(self) -> None: self._packet_receiver.close() async def _resolve_address_nocp( - self, address: object, *, local: bool + self, + address: object, + *, + local: bool, ) -> tuple[str, int]: return await trio._socket._resolve_address_nocp( # type: ignore[no-any-return] self.type, @@ -305,7 +315,7 @@ async def _sendmsg( buffers: Iterable[Buffer], ancdata: Iterable[tuple[int, int, Buffer]] = (), flags: int = 0, - address: Any | None = None, + address: AddressFormat | None = None, ) -> int: self._check_closed() @@ -349,7 +359,12 @@ async def _recvmsg_into( buffers: Iterable[Buffer], ancbufsize: int = 0, flags: int = 0, - ) -> tuple[int, list[tuple[int, int, bytes]], int, Any]: + ) -> tuple[ + int, + list[tuple[int, int, bytes]], + int, + tuple[str, int] | tuple[str, int, int, int], + ]: if ancbufsize != 0: raise NotImplementedError("FakeNet doesn't support ancillary data") if flags != 0: @@ -360,7 +375,7 @@ async def _recvmsg_into( raise NotImplementedError( "The code will most likely hang if you try to receive on a fakesocket " "without a binding. If that is not the case, or you explicitly want to " - "test that, remove this warning." + "test that, remove this warning.", ) self._check_closed() @@ -399,11 +414,13 @@ def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]: self._check_closed() if self._binding is not None: assert hasattr( - self._binding, "remote" + self._binding, + "remote", ), "This method seems to assume that self._binding has a remote UDPEndpoint" if self._binding.remote is not None: # pragma: no cover assert isinstance( - self._binding.remote, UDPEndpoint + self._binding.remote, + UDPEndpoint, ), "Self._binding.remote should be a UDPEndpoint" return self._binding.remote.as_python_sockaddr() _fake_err(errno.ENOTCONN) @@ -415,7 +432,11 @@ def getsockopt(self, /, level: int, optname: int) -> int: ... def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... def getsockopt( - self, /, level: int, optname: int, buflen: int | None = None + self, + /, + level: int, + optname: int, + buflen: int | None = None, ) -> int | bytes: self._check_closed() raise OSError(f"FakeNet doesn't implement getsockopt({level}, {optname})") @@ -425,7 +446,12 @@ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: @overload def setsockopt( - self, /, level: int, optname: int, value: None, optlen: int + self, + /, + level: int, + optname: int, + value: None, + optlen: int, ) -> None: ... def setsockopt( @@ -464,20 +490,28 @@ def __exit__( async def send(self, data: Buffer, flags: int = 0) -> int: return await self.sendto(data, flags, None) + # __ prefixed arguments because typeshed uses that and typechecker issues @overload async def sendto( - self, __data: Buffer, __address: tuple[object, ...] | str | Buffer + self, + __data: Buffer, # noqa: PYI063 + __address: tuple[object, ...] | str | Buffer, ) -> int: ... + # __ prefixed arguments because typeshed uses that and typechecker issues @overload async def sendto( self, - __data: Buffer, + __data: Buffer, # noqa: PYI063 __flags: int, - __address: tuple[object, ...] | str | None | Buffer, + __address: tuple[object, ...] | str | Buffer | None, ) -> int: ... - async def sendto(self, *args: Any) -> int: + # Explicit "Any" is not allowed + async def sendto( # type: ignore[misc] + self, + *args: Any, + ) -> int: data: Buffer flags: int address: tuple[object, ...] | str | Buffer @@ -491,33 +525,47 @@ async def sendto(self, *args: Any) -> int: return await self._sendmsg([data], [], flags, address) async def recv(self, bufsize: int, flags: int = 0) -> bytes: - data, address = await self.recvfrom(bufsize, flags) + data, _address = await self.recvfrom(bufsize, flags) return data async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int: - got_bytes, address = await self.recvfrom_into(buf, nbytes, flags) + got_bytes, _address = await self.recvfrom_into(buf, nbytes, flags) return got_bytes - async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]: - data, ancdata, msg_flags, address = await self._recvmsg(bufsize, flags) + async def recvfrom( + self, + bufsize: int, + flags: int = 0, + ) -> tuple[bytes, AddressFormat]: + data, _ancdata, _msg_flags, address = await self._recvmsg(bufsize, flags) return data, address async def recvfrom_into( - self, buf: Buffer, nbytes: int = 0, flags: int = 0 - ) -> tuple[int, Any]: + self, + buf: Buffer, + nbytes: int = 0, + flags: int = 0, + ) -> tuple[int, AddressFormat]: if nbytes != 0 and nbytes != memoryview(buf).nbytes: raise NotImplementedError("partial recvfrom_into") - got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into( - [buf], 0, flags + got_nbytes, _ancdata, _msg_flags, address = await self._recvmsg_into( + [buf], + 0, + flags, ) return got_nbytes, address async def _recvmsg( - self, bufsize: int, ancbufsize: int = 0, flags: int = 0 - ) -> tuple[bytes, list[tuple[int, int, bytes]], int, Any]: + self, + bufsize: int, + ancbufsize: int = 0, + flags: int = 0, + ) -> tuple[bytes, list[tuple[int, int, bytes]], int, AddressFormat]: buf = bytearray(bufsize) got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into( - [buf], ancbufsize, flags + [buf], + ancbufsize, + flags, ) return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address) diff --git a/src/trio/testing/_memory_streams.py b/src/trio/testing/_memory_streams.py index c9d430a9e6..6dd48ebf3d 100644 --- a/src/trio/testing/_memory_streams.py +++ b/src/trio/testing/_memory_streams.py @@ -1,7 +1,8 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, TypeVar from .. import _core, _util from .._highlevel_generic import StapledStream @@ -29,7 +30,7 @@ def __init__(self) -> None: self._closed = False self._lot = _core.ParkingLot() self._fetch_lock = _util.ConflictDetector( - "another task is already fetching data" + "another task is already fetching data", ) # This object treats "close" as being like closing the send side of a @@ -113,9 +114,9 @@ def __init__( send_all_hook: AsyncHook | None = None, wait_send_all_might_not_block_hook: AsyncHook | None = None, close_hook: SyncHook | None = None, - ): + ) -> None: self._conflict_detector = _util.ConflictDetector( - "another task is using this stream" + "another task is using this stream", ) self._outgoing = _UnboundedByteQueue() self.send_all_hook = send_all_hook @@ -223,9 +224,9 @@ def __init__( self, receive_some_hook: AsyncHook | None = None, close_hook: SyncHook | None = None, - ): + ) -> None: self._conflict_detector = _util.ConflictDetector( - "another task is using this stream" + "another task is using this stream", ) self._incoming = _UnboundedByteQueue() self._closed = False @@ -347,7 +348,8 @@ def memory_stream_one_way_pair() -> tuple[MemorySendStream, MemoryReceiveStream] def pump_from_send_stream_to_recv_stream() -> None: memory_stream_pump(send_stream, recv_stream) - async def async_pump_from_send_stream_to_recv_stream() -> None: + # await not used + async def async_pump_from_send_stream_to_recv_stream() -> None: # noqa: RUF029 pump_from_send_stream_to_recv_stream() send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream @@ -356,7 +358,7 @@ async def async_pump_from_send_stream_to_recv_stream() -> None: def _make_stapled_pair( - one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]] + one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]], ) -> tuple[ StapledStream[SendStreamT, ReceiveStreamT], StapledStream[SendStreamT, ReceiveStreamT], @@ -461,10 +463,10 @@ def __init__(self) -> None: self._receiver_waiting = False self._waiters = _core.ParkingLot() self._send_conflict_detector = _util.ConflictDetector( - "another task is already sending" + "another task is already sending", ) self._receive_conflict_detector = _util.ConflictDetector( - "another task is already receiving" + "another task is already receiving", ) def _something_happened(self) -> None: @@ -548,7 +550,7 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: class _LockstepSendStream(SendStream): - def __init__(self, lbq: _LockstepByteQueue): + def __init__(self, lbq: _LockstepByteQueue) -> None: self._lbq = lbq def close(self) -> None: @@ -566,7 +568,7 @@ async def wait_send_all_might_not_block(self) -> None: class _LockstepReceiveStream(ReceiveStream): - def __init__(self, lbq: _LockstepByteQueue): + def __init__(self, lbq: _LockstepByteQueue) -> None: self._lbq = lbq def close(self) -> None: diff --git a/src/trio/testing/_raises_group.py b/src/trio/testing/_raises_group.py index f96dcb2351..700c16ca6a 100644 --- a/src/trio/testing/_raises_group.py +++ b/src/trio/testing/_raises_group.py @@ -2,19 +2,15 @@ import re import sys +from re import Pattern from typing import ( TYPE_CHECKING, - Callable, - ContextManager, Generic, Literal, - Pattern, - Sequence, cast, overload, ) -from trio._deprecate import warn_deprecated from trio._util import final if TYPE_CHECKING: @@ -23,23 +19,32 @@ # sphinx will *only* work if we use types.TracebackType, and import # *inside* TYPE_CHECKING. No other combination works..... import types + from collections.abc import Callable, Sequence from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback from typing_extensions import TypeGuard, TypeVar + # this conditional definition is because we want to allow a TypeVar default MatchE = TypeVar( - "MatchE", bound=BaseException, default=BaseException, covariant=True + "MatchE", + bound=BaseException, + default=BaseException, + covariant=True, ) else: from typing import TypeVar MatchE = TypeVar("MatchE", bound=BaseException, covariant=True) + # RaisesGroup doesn't work with a default. -E = TypeVar("E", bound=BaseException, covariant=True) -# These two typevars are special cased in sphinx config to workaround lookup bugs. +BaseExcT_co = TypeVar("BaseExcT_co", bound=BaseException, covariant=True) +BaseExcT_1 = TypeVar("BaseExcT_1", bound=BaseException) +BaseExcT_2 = TypeVar("BaseExcT_2", bound=BaseException) +ExcT_1 = TypeVar("ExcT_1", bound=Exception) +ExcT_2 = TypeVar("ExcT_2", bound=Exception) if sys.version_info < (3, 11): - from exceptiongroup import BaseExceptionGroup + from exceptiongroup import BaseExceptionGroup, ExceptionGroup @final @@ -49,12 +54,14 @@ class _ExceptionInfo(Generic[MatchE]): _excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None def __init__( - self, excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None - ): + self, + excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None, + ) -> None: self._excinfo = excinfo def fill_unfilled( - self, exc_info: tuple[type[MatchE], MatchE, types.TracebackType] + self, + exc_info: tuple[type[MatchE], MatchE, types.TracebackType], ) -> None: """Fill an unfilled ExceptionInfo created with ``for_later()``.""" assert self._excinfo is None, "ExceptionInfo was already filled" @@ -92,7 +99,7 @@ def tb(self) -> types.TracebackType: def exconly(self, tryshort: bool = False) -> str: raise NotImplementedError( - "This is a helper method only available if you use RaisesGroup with the pytest package installed" + "This is a helper method only available if you use RaisesGroup with the pytest package installed", ) def errisinstance( @@ -100,7 +107,7 @@ def errisinstance( exc: builtins.type[BaseException] | tuple[builtins.type[BaseException], ...], ) -> bool: raise NotImplementedError( - "This is a helper method only available if you use RaisesGroup with the pytest package installed" + "This is a helper method only available if you use RaisesGroup with the pytest package installed", ) def getrepr( @@ -114,7 +121,7 @@ def getrepr( chain: bool = True, ) -> ReprExceptionInfo | ExceptionChainRepr: raise NotImplementedError( - "This is a helper method only available if you use RaisesGroup with the pytest package installed" + "This is a helper method only available if you use RaisesGroup with the pytest package installed", ) @@ -139,12 +146,12 @@ def _stringify_exception(exc: BaseException) -> str: [ getattr(exc, "message", str(exc)), *getattr(exc, "__notes__", []), - ] + ], ) # String patterns default to including the unicode flag. -_regex_no_flags = re.compile("").flags +_REGEX_NO_FLAGS = re.compile(r"").flags @final @@ -171,7 +178,7 @@ def __init__( exception_type: type[MatchE], match: str | Pattern[str] = ..., check: Callable[[MatchE], bool] = ..., - ): ... + ) -> None: ... @overload def __init__( @@ -180,10 +187,10 @@ def __init__( match: str | Pattern[str], # If exception_type is not provided, check() must do any typechecks itself. check: Callable[[BaseException], bool] = ..., - ): ... + ) -> None: ... @overload - def __init__(self, *, check: Callable[[BaseException], bool]): ... + def __init__(self, *, check: Callable[[BaseException], bool]) -> None: ... def __init__( self, @@ -195,7 +202,7 @@ def __init__( raise ValueError("You must specify at least one parameter to match on.") if exception_type is not None and not issubclass(exception_type, BaseException): raise ValueError( - f"exception_type {exception_type} must be a subclass of BaseException" + f"exception_type {exception_type} must be a subclass of BaseException", ) self.exception_type = exception_type self.match: Pattern[str] | None @@ -224,16 +231,18 @@ def matches(self, exception: BaseException) -> TypeGuard[MatchE]: """ if self.exception_type is not None and not isinstance( - exception, self.exception_type + exception, + self.exception_type, ): return False if self.match is not None and not re.search( - self.match, _stringify_exception(exception) + self.match, + _stringify_exception(exception), ): return False # If exception_type is None check() accepts BaseException. # If non-none, we have done an isinstance check above. - return self.check is None or self.check(cast(MatchE, exception)) + return self.check is None or self.check(cast("MatchE", exception)) def __str__(self) -> str: reqs = [] @@ -242,36 +251,15 @@ def __str__(self) -> str: if (match := self.match) is not None: # If no flags were specified, discard the redundant re.compile() here. reqs.append( - f"match={match.pattern if match.flags == _regex_no_flags else match!r}" + f"match={match.pattern if match.flags == _REGEX_NO_FLAGS else match!r}", ) if self.check is not None: reqs.append(f"check={self.check!r}") return f'Matcher({", ".join(reqs)})' -# typing this has been somewhat of a nightmare, with the primary difficulty making -# the return type of __enter__ correct. Ideally it would function like this -# with RaisesGroup(RaisesGroup(ValueError)) as excinfo: -# ... -# assert_type(excinfo.value, ExceptionGroup[ExceptionGroup[ValueError]]) -# in addition to all the simple cases, but getting all the way to the above seems maybe -# impossible. The type being RaisesGroup[RaisesGroup[ValueError]] is probably also fine, -# as long as I add fake properties corresponding to the properties of exceptiongroup. But -# I had trouble with it handling recursive cases properly. - -# Current solution settles on the above giving BaseExceptionGroup[RaisesGroup[ValueError]], and it not -# being a type error to do `with RaisesGroup(ValueError()): ...` - but that will error on runtime. - -# We lie to type checkers that we inherit, so excinfo.value and sub-exceptiongroups can be treated as ExceptionGroups -if TYPE_CHECKING: - SuperClass = BaseExceptionGroup -else: - # At runtime, use a redundant Generic base class which effectively gets ignored. - SuperClass = Generic - - @final -class RaisesGroup(ContextManager[ExceptionInfo[BaseExceptionGroup[E]]], SuperClass[E]): +class RaisesGroup(Generic[BaseExcT_co]): """Contextmanager for checking for an expected `ExceptionGroup`. This works similar to ``pytest.raises``, and a version of it will hopefully be added upstream, after which this can be deprecated and removed. See https://github.com/pytest-dev/pytest/issues/11538 @@ -324,62 +312,121 @@ class RaisesGroup(ContextManager[ExceptionInfo[BaseExceptionGroup[E]]], SuperCla even though it generally does not care about the order of the exceptions in the group. To avoid the above you should specify the first ValueError with a Matcher as well. - - It is also not typechecked perfectly, and that's likely not possible with the current approach. Most common usage should work without issue though. """ - # needed for pyright, since BaseExceptionGroup.__new__ takes two arguments - if TYPE_CHECKING: - - def __new__(cls, *args: object, **kwargs: object) -> RaisesGroup[E]: ... - # allow_unwrapped=True requires: singular exception, exception not being # RaisesGroup instance, match is None, check is None @overload def __init__( self, - exception: type[E] | Matcher[E], + exception: type[BaseExcT_co] | Matcher[BaseExcT_co], *, allow_unwrapped: Literal[True], flatten_subgroups: bool = False, - match: None = None, - check: None = None, - ): ... + ) -> None: ... # flatten_subgroups = True also requires no nested RaisesGroup @overload def __init__( self, - exception: type[E] | Matcher[E], - *other_exceptions: type[E] | Matcher[E], - allow_unwrapped: Literal[False] = False, + exception: type[BaseExcT_co] | Matcher[BaseExcT_co], + *other_exceptions: type[BaseExcT_co] | Matcher[BaseExcT_co], flatten_subgroups: Literal[True], match: str | Pattern[str] | None = None, - check: Callable[[BaseExceptionGroup[E]], bool] | None = None, - ): ... + check: Callable[[BaseExceptionGroup[BaseExcT_co]], bool] | None = None, + ) -> None: ... + + # simplify the typevars if possible (the following 3 are equivalent but go simpler->complicated) + # ... the first handles RaisesGroup[ValueError], the second RaisesGroup[ExceptionGroup[ValueError]], + # the third RaisesGroup[ValueError | ExceptionGroup[ValueError]]. + # ... otherwise, we will get results like RaisesGroup[ValueError | ExceptionGroup[Never]] (I think) + # (technically correct but misleading) + @overload + def __init__( + self: RaisesGroup[ExcT_1], + exception: type[ExcT_1] | Matcher[ExcT_1], + *other_exceptions: type[ExcT_1] | Matcher[ExcT_1], + match: str | Pattern[str] | None = None, + check: Callable[[ExceptionGroup[ExcT_1]], bool] | None = None, + ) -> None: ... @overload def __init__( - self, - exception: type[E] | Matcher[E] | E, - *other_exceptions: type[E] | Matcher[E] | E, - allow_unwrapped: Literal[False] = False, - flatten_subgroups: Literal[False] = False, + self: RaisesGroup[ExceptionGroup[ExcT_2]], + exception: RaisesGroup[ExcT_2], + *other_exceptions: RaisesGroup[ExcT_2], match: str | Pattern[str] | None = None, - check: Callable[[BaseExceptionGroup[E]], bool] | None = None, - ): ... + check: Callable[[ExceptionGroup[ExceptionGroup[ExcT_2]]], bool] | None = None, + ) -> None: ... + @overload def __init__( - self, - exception: type[E] | Matcher[E] | E, - *other_exceptions: type[E] | Matcher[E] | E, + self: RaisesGroup[ExcT_1 | ExceptionGroup[ExcT_2]], + exception: type[ExcT_1] | Matcher[ExcT_1] | RaisesGroup[ExcT_2], + *other_exceptions: type[ExcT_1] | Matcher[ExcT_1] | RaisesGroup[ExcT_2], + match: str | Pattern[str] | None = None, + check: ( + Callable[[ExceptionGroup[ExcT_1 | ExceptionGroup[ExcT_2]]], bool] | None + ) = None, + ) -> None: ... + + # same as the above 3 but handling BaseException + @overload + def __init__( + self: RaisesGroup[BaseExcT_1], + exception: type[BaseExcT_1] | Matcher[BaseExcT_1], + *other_exceptions: type[BaseExcT_1] | Matcher[BaseExcT_1], + match: str | Pattern[str] | None = None, + check: Callable[[BaseExceptionGroup[BaseExcT_1]], bool] | None = None, + ) -> None: ... + + @overload + def __init__( + self: RaisesGroup[BaseExceptionGroup[BaseExcT_2]], + exception: RaisesGroup[BaseExcT_2], + *other_exceptions: RaisesGroup[BaseExcT_2], + match: str | Pattern[str] | None = None, + check: ( + Callable[[BaseExceptionGroup[BaseExceptionGroup[BaseExcT_2]]], bool] | None + ) = None, + ) -> None: ... + + @overload + def __init__( + self: RaisesGroup[BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]], + exception: type[BaseExcT_1] | Matcher[BaseExcT_1] | RaisesGroup[BaseExcT_2], + *other_exceptions: type[BaseExcT_1] + | Matcher[BaseExcT_1] + | RaisesGroup[BaseExcT_2], + match: str | Pattern[str] | None = None, + check: ( + Callable[ + [BaseExceptionGroup[BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]]], + bool, + ] + | None + ) = None, + ) -> None: ... + + def __init__( + self: RaisesGroup[ExcT_1 | BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]], + exception: type[BaseExcT_1] | Matcher[BaseExcT_1] | RaisesGroup[BaseExcT_2], + *other_exceptions: type[BaseExcT_1] + | Matcher[BaseExcT_1] + | RaisesGroup[BaseExcT_2], allow_unwrapped: bool = False, flatten_subgroups: bool = False, match: str | Pattern[str] | None = None, - check: Callable[[BaseExceptionGroup[E]], bool] | None = None, - strict: None = None, + check: ( + Callable[[BaseExceptionGroup[BaseExcT_1]], bool] + | Callable[[ExceptionGroup[ExcT_1]], bool] + | None + ) = None, ): - self.expected_exceptions: tuple[type[E] | Matcher[E] | E, ...] = ( + self.expected_exceptions: tuple[ + type[BaseExcT_co] | Matcher[BaseExcT_co] | RaisesGroup[BaseException], + ..., + ] = ( exception, *other_exceptions, ) @@ -389,27 +436,18 @@ def __init__( self.check = check self.is_baseexceptiongroup = False - if strict is not None: - warn_deprecated( - "The `strict` parameter", - "0.25.1", - issue=2989, - instead="flatten_subgroups=True (for strict=False}", - ) - self.flatten_subgroups = not strict - if allow_unwrapped and other_exceptions: raise ValueError( "You cannot specify multiple exceptions with `allow_unwrapped=True.`" " If you want to match one of multiple possible exceptions you should" " use a `Matcher`." - " E.g. `Matcher(check=lambda e: isinstance(e, (...)))`" + " E.g. `Matcher(check=lambda e: isinstance(e, (...)))`", ) if allow_unwrapped and isinstance(exception, RaisesGroup): raise ValueError( "`allow_unwrapped=True` has no effect when expecting a `RaisesGroup`." " You might want it in the expected `RaisesGroup`, or" - " `flatten_subgroups=True` if you don't care about the structure." + " `flatten_subgroups=True` if you don't care about the structure.", ) if allow_unwrapped and (match is not None or check is not None): raise ValueError( @@ -418,7 +456,7 @@ def __init__( " exception you should use a `Matcher` object. If you want to match/check" " the exceptiongroup when the exception *is* wrapped you need to" " do e.g. `if isinstance(exc.value, ExceptionGroup):" - " assert RaisesGroup(...).matches(exc.value)` afterwards." + " assert RaisesGroup(...).matches(exc.value)` afterwards.", ) # verify `expected_exceptions` and set `self.is_baseexceptiongroup` @@ -429,7 +467,7 @@ def __init__( "You cannot specify a nested structure inside a RaisesGroup with" " `flatten_subgroups=True`. The parameter will flatten subgroups" " in the raised exceptiongroup before matching, which would never" - " match a nested structure." + " match a nested structure.", ) self.is_baseexceptiongroup |= exc.is_baseexceptiongroup elif isinstance(exc, Matcher): @@ -439,22 +477,35 @@ def __init__( continue # Matcher __init__ assures it's a subclass of BaseException self.is_baseexceptiongroup |= not issubclass( - exc.exception_type, Exception + exc.exception_type, + Exception, ) elif isinstance(exc, type) and issubclass(exc, BaseException): self.is_baseexceptiongroup |= not issubclass(exc, Exception) else: raise ValueError( f'Invalid argument "{exc!r}" must be exception type, Matcher, or' - " RaisesGroup." + " RaisesGroup.", ) - def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[E]]: - self.excinfo: ExceptionInfo[BaseExceptionGroup[E]] = ExceptionInfo.for_later() + @overload + def __enter__( + self: RaisesGroup[ExcT_1], + ) -> ExceptionInfo[ExceptionGroup[ExcT_1]]: ... + @overload + def __enter__( + self: RaisesGroup[BaseExcT_1], + ) -> ExceptionInfo[BaseExceptionGroup[BaseExcT_1]]: ... + + def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[BaseException]]: + self.excinfo: ExceptionInfo[BaseExceptionGroup[BaseExcT_co]] = ( + ExceptionInfo.for_later() + ) return self.excinfo def _unroll_exceptions( - self, exceptions: Sequence[BaseException] + self, + exceptions: Sequence[BaseException], ) -> Sequence[BaseException]: """Used if `flatten_subgroups=True`.""" res: list[BaseException] = [] @@ -466,10 +517,21 @@ def _unroll_exceptions( res.append(exc) return res + @overload + def matches( + self: RaisesGroup[ExcT_1], + exc_val: BaseException | None, + ) -> TypeGuard[ExceptionGroup[ExcT_1]]: ... + @overload + def matches( + self: RaisesGroup[BaseExcT_1], + exc_val: BaseException | None, + ) -> TypeGuard[BaseExceptionGroup[BaseExcT_1]]: ... + def matches( self, exc_val: BaseException | None, - ) -> TypeGuard[BaseExceptionGroup[E]]: + ) -> TypeGuard[BaseExceptionGroup[BaseExcT_co]]: """Check if an exception matches the requirements of this RaisesGroup. Example:: @@ -498,11 +560,10 @@ def matches( return False if self.match_expr is not None and not re.search( - self.match_expr, _stringify_exception(exc_val) + self.match_expr, + _stringify_exception(exc_val), ): return False - if self.check is not None and not self.check(exc_val): - return False remaining_exceptions = list(self.expected_exceptions) actual_exceptions: Sequence[BaseException] = exc_val.exceptions @@ -513,9 +574,6 @@ def matches( if len(actual_exceptions) != len(self.expected_exceptions): return False - # it should be possible to get RaisesGroup.matches typed so as not to - # need type: ignore, but I'm not sure that's possible while also having it - # transparent for the end user. for e in actual_exceptions: for rem_e in remaining_exceptions: if ( @@ -523,11 +581,14 @@ def matches( or (isinstance(rem_e, RaisesGroup) and rem_e.matches(e)) or (isinstance(rem_e, Matcher) and rem_e.matches(e)) ): - remaining_exceptions.remove(rem_e) # type: ignore[arg-type] + remaining_exceptions.remove(rem_e) break else: return False - return True + + # only run `self.check` once we know `exc_val` is correct. (see the types) + # unfortunately mypy isn't smart enough to recognize the above `for`s as narrowing. + return self.check is None or self.check(exc_val) # type: ignore[arg-type] def __exit__( self, @@ -548,7 +609,7 @@ def __exit__( # Cast to narrow the exception type now that it's verified. exc_info = cast( - "tuple[type[BaseExceptionGroup[E]], BaseExceptionGroup[E], types.TracebackType]", + "tuple[type[BaseExceptionGroup[BaseExcT_co]], BaseExceptionGroup[BaseExcT_co], types.TracebackType]", (exc_type, exc_val, exc_tb), ) self.excinfo.fill_unfilled(exc_info) diff --git a/src/trio/testing/_sequencer.py b/src/trio/testing/_sequencer.py index 2bade1b315..32171cb2a2 100644 --- a/src/trio/testing/_sequencer.py +++ b/src/trio/testing/_sequencer.py @@ -13,7 +13,7 @@ @_util.final -@attrs.define(eq=False, hash=False, slots=False) +@attrs.define(eq=False, slots=False) class Sequencer: """A convenience class for forcing code in different tasks to run in an explicit linear order. @@ -55,7 +55,8 @@ async def main(): """ _sequence_points: defaultdict[int, Event] = attrs.field( - factory=lambda: defaultdict(Event), init=False + factory=lambda: defaultdict(Event), + init=False, ) _claimed: set[int] = attrs.field(factory=set, init=False) _broken: bool = attrs.field(default=False, init=False) @@ -75,7 +76,7 @@ async def __call__(self, position: int) -> AsyncIterator[None]: for event in self._sequence_points.values(): event.set() raise RuntimeError( - "Sequencer wait cancelled -- sequence broken" + "Sequencer wait cancelled -- sequence broken", ) from None else: if self._broken: diff --git a/src/trio/testing/_trio_test.py b/src/trio/testing/_trio_test.py index a57c0ee4c7..226e559196 100644 --- a/src/trio/testing/_trio_test.py +++ b/src/trio/testing/_trio_test.py @@ -42,7 +42,9 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: raise ValueError("too many clocks spoil the broth!") instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] return _core.run( - partial(fn, *args, **kwargs), clock=clock, instruments=instruments + partial(fn, *args, **kwargs), + clock=clock, + instruments=instruments, ) return wrapper diff --git a/test-requirements.in b/test-requirements.in index af8a751b13..809e171e3b 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -6,22 +6,23 @@ pyright pyOpenSSL >= 22.0.0 # for the ssl + DTLS tests trustme # for the ssl + DTLS tests pylint # for pylint finding all symbols tests -jedi # for jedi code completion tests +jedi; implementation_name == "cpython" # for jedi code completion tests cryptography>=41.0.0 # cryptography<41 segfaults on pypy3.10 # Tools black; implementation_name == "cpython" -mypy; implementation_name == "cpython" -types-pyOpenSSL; implementation_name == "cpython" # and annotations -ruff >= 0.4.3 +mypy # Would use mypy[faster-cache], but orjson has build issues on pypy +orjson; implementation_name == "cpython" +ruff >= 0.8.0 astor # code generation uv >= 0.2.24 codespell # https://github.com/python-trio/trio/pull/654#issuecomment-420518745 -mypy-extensions; implementation_name == "cpython" +mypy-extensions typing-extensions -types-cffi; implementation_name == "cpython" +types-cffi +types-pyOpenSSL # annotations in doc files types-docutils sphinx diff --git a/test-requirements.txt b/test-requirements.txt index 3dec7a570b..87b3c581eb 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,119 +1,119 @@ # This file was autogenerated by uv via the following command: -# uv pip compile --universal --python-version=3.8 test-requirements.in -o test-requirements.txt -alabaster==0.7.13 +# uv pip compile --universal --python-version=3.9 test-requirements.in -o test-requirements.txt +alabaster==0.7.16 # via sphinx astor==0.8.1 # via -r test-requirements.in -astroid==3.2.2 +astroid==3.3.5 # via pylint async-generator==1.10 # via -r test-requirements.in -attrs==23.2.0 +attrs==24.2.0 # via # -r test-requirements.in # outcome -babel==2.15.0 +babel==2.16.0 # via sphinx -black==24.4.2 ; implementation_name == 'cpython' +black==24.10.0 ; implementation_name == 'cpython' # via -r test-requirements.in -certifi==2024.7.4 +certifi==2024.8.30 # via requests -cffi==1.17.0rc1 ; os_name == 'nt' or platform_python_implementation != 'PyPy' +cffi==1.17.1 ; platform_python_implementation != 'PyPy' or os_name == 'nt' # via # -r test-requirements.in # cryptography -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 # via requests click==8.1.7 ; implementation_name == 'cpython' # via black codespell==2.3.0 # via -r test-requirements.in -colorama==0.4.6 ; sys_platform == 'win32' or (implementation_name == 'cpython' and platform_system == 'Windows') +colorama==0.4.6 ; (implementation_name != 'cpython' and sys_platform == 'win32') or (platform_system != 'Windows' and sys_platform == 'win32') or (implementation_name == 'cpython' and platform_system == 'Windows') # via # click # pylint # pytest # sphinx -coverage==7.5.4 +coverage==7.6.8 # via -r test-requirements.in -cryptography==42.0.8 +cryptography==43.0.3 # via # -r test-requirements.in # pyopenssl # trustme # types-pyopenssl -dill==0.3.8 +dill==0.3.9 # via pylint -docutils==0.20.1 +docutils==0.21.2 # via sphinx -exceptiongroup==1.2.1 ; python_version < '3.11' +exceptiongroup==1.2.2 ; python_full_version < '3.11' # via # -r test-requirements.in # pytest -idna==3.7 +idna==3.10 # via # -r test-requirements.in # requests # trustme imagesize==1.4.1 # via sphinx -importlib-metadata==8.0.0 ; python_version < '3.10' +importlib-metadata==8.5.0 ; python_full_version < '3.10' # via sphinx iniconfig==2.0.0 # via pytest isort==5.13.2 # via pylint -jedi==0.19.1 +jedi==0.19.2 ; implementation_name == 'cpython' # via -r test-requirements.in jinja2==3.1.4 # via sphinx -markupsafe==2.1.5 +markupsafe==3.0.2 # via jinja2 mccabe==0.7.0 # via pylint -mypy==1.11.0 ; implementation_name == 'cpython' +mypy==1.13.0 # via -r test-requirements.in -mypy-extensions==1.0.0 ; implementation_name == 'cpython' +mypy-extensions==1.0.0 # via # -r test-requirements.in # black # mypy nodeenv==1.9.1 # via pyright +orjson==3.10.12 ; implementation_name == 'cpython' + # via -r test-requirements.in outcome==1.3.0.post0 # via -r test-requirements.in -packaging==24.1 +packaging==24.2 # via # black # pytest # sphinx -parso==0.8.4 +parso==0.8.4 ; implementation_name == 'cpython' # via jedi pathspec==0.12.1 ; implementation_name == 'cpython' # via black -platformdirs==4.2.2 +platformdirs==4.3.6 # via # black # pylint pluggy==1.5.0 # via pytest -pycparser==2.22 ; os_name == 'nt' or platform_python_implementation != 'PyPy' +pycparser==2.22 ; platform_python_implementation != 'PyPy' or os_name == 'nt' # via cffi pygments==2.18.0 # via sphinx -pylint==3.2.5 +pylint==3.3.1 # via -r test-requirements.in -pyopenssl==24.1.0 +pyopenssl==24.2.1 # via -r test-requirements.in -pyright==1.1.370 +pyright==1.1.389 # via -r test-requirements.in -pytest==8.2.2 +pytest==8.3.3 # via -r test-requirements.in -pytz==2024.1 ; python_version < '3.9' - # via babel requests==2.32.3 # via sphinx -ruff==0.5.1 +ruff==0.8.2 # via -r test-requirements.in sniffio==1.3.1 # via -r test-requirements.in @@ -121,39 +121,40 @@ snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via -r test-requirements.in -sphinx==7.1.2 +sphinx==7.4.7 # via -r test-requirements.in -sphinxcontrib-applehelp==1.0.4 +sphinxcontrib-applehelp==2.0.0 # via sphinx -sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-devhelp==2.0.0 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 +sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-qthelp==2.0.0 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib-serializinghtml==2.0.0 # via sphinx -tomli==2.0.1 ; python_version < '3.11' +tomli==2.2.1 ; python_full_version < '3.11' # via # black # mypy # pylint # pytest -tomlkit==0.12.5 + # sphinx +tomlkit==0.13.2 # via pylint -trustme==1.1.0 +trustme==1.2.0 # via -r test-requirements.in -types-cffi==1.16.0.20240331 ; implementation_name == 'cpython' +types-cffi==1.16.0.20240331 # via # -r test-requirements.in # types-pyopenssl -types-docutils==0.21.0.20240704 +types-docutils==0.21.0.20241128 # via -r test-requirements.in -types-pyopenssl==24.1.0.20240425 ; implementation_name == 'cpython' +types-pyopenssl==24.1.0.20240722 # via -r test-requirements.in -types-setuptools==70.2.0.20240704 ; implementation_name == 'cpython' +types-setuptools==75.6.0.20241126 # via types-cffi typing-extensions==4.12.2 # via @@ -162,9 +163,10 @@ typing-extensions==4.12.2 # black # mypy # pylint -urllib3==2.2.2 + # pyright +urllib3==2.2.3 # via requests -uv==0.2.26 +uv==0.5.5 # via -r test-requirements.in -zipp==3.19.2 ; python_version < '3.10' +zipp==3.21.0 ; python_full_version < '3.10' # via importlib-metadata diff --git a/tests/_trio_check_attrs_aliases.py b/tests/_trio_check_attrs_aliases.py new file mode 100644 index 0000000000..b4a339dabc --- /dev/null +++ b/tests/_trio_check_attrs_aliases.py @@ -0,0 +1,22 @@ +"""Plugins are executed by Pytest before test modules. + +We use this to monkeypatch attrs.field(), so that we can detect if aliases are used for test_exports. +""" + +from typing import Any + +import attrs + +orig_field = attrs.field + + +def field(**kwargs: Any) -> Any: + original_args = kwargs.copy() + metadata = kwargs.setdefault("metadata", {}) + metadata["trio_original_args"] = original_args + return orig_field(**kwargs) + + +# Mark it as being ours, so the test knows it can actually run. +field.trio_modded = True # type: ignore +attrs.field = field diff --git a/tests/cython/run_test_cython.py b/tests/cython/run_test_cython.py new file mode 100644 index 0000000000..0c4e043b59 --- /dev/null +++ b/tests/cython/run_test_cython.py @@ -0,0 +1,3 @@ +from .test_cython import invoke_main_entry_point + +invoke_main_entry_point() diff --git a/tests/cython/test_cython.pyx b/tests/cython/test_cython.pyx index b836caf90c..77857eec4b 100644 --- a/tests/cython/test_cython.pyx +++ b/tests/cython/test_cython.pyx @@ -19,4 +19,5 @@ async def trio_main() -> None: nursery.start_soon(foo) nursery.start_soon(foo) -trio.run(trio_main) +def invoke_main_entry_point(): + trio.run(trio_main)