diff --git a/docker/mlflow/docker-compose.yml b/docker/mlflow/docker-compose.yml index cb1acbbc8..2b11bb5e2 100644 --- a/docker/mlflow/docker-compose.yml +++ b/docker/mlflow/docker-compose.yml @@ -64,8 +64,6 @@ services: - "5000:5000" environment: - MLFLOW_S3_ENDPOINT_URL=http://minio:9000 - - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} - - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} command: mlflow server --backend-store-uri mysql+pymysql://${MYSQL_USER}:${MYSQL_PASSWORD}@db:3306/${MYSQL_DATABASE} --default-artifact-root s3://mlflow/ --host 0.0.0.0 volumes: diff --git a/params.yaml b/params.yaml index 51cce0865..7a26f0c32 100644 --- a/params.yaml +++ b/params.yaml @@ -9,15 +9,46 @@ active: - noise_removal models: - logistic_regression + - knn + - svm + - gradient_boosting_classifier datasets: - - diabetes + - cifar10 + - click + - covertype - cpu + - diabetes + - fmnist_binary + - mnist_binary + - mnist_multi + - phoneme valuation_methods: - random - loo + - beta_shapley + - tmc_shapley + - classwise_shapley repetitions: - 1 - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 10 + - 11 + - 12 + - 13 + - 14 + - 15 + - 16 + - 17 + - 18 + - 19 + - 20 experiments: point_removal: sampler: default @@ -27,11 +58,36 @@ experiments: metric: accuracy eval_model: logistic_regression len_curve_perc: 0.5 + weighted_accuracy_drop_logistic_regression_balanced: + idx: weighted_metric_drop + metric: accuracy + eval_model: logistic_regression_balanced + len_curve_perc: 0.5 weighted_accuracy_drop_knn: idx: weighted_metric_drop metric: accuracy eval_model: knn len_curve_perc: 0.5 + weighted_accuracy_drop_gradient_boosting_classifier: + idx: weighted_metric_drop + metric: accuracy + eval_model: gradient_boosting_classifier + len_curve_perc: 0.5 + weighted_accuracy_drop_svm: + idx: weighted_metric_drop + metric: accuracy + eval_model: svm + len_curve_perc: 0.5 + weighted_accuracy_drop_svm_balanced: + idx: weighted_metric_drop + metric: accuracy + eval_model: svm_balanced + len_curve_perc: 0.5 + weighted_accuracy_drop_mlp: + idx: weighted_metric_drop + metric: accuracy + eval_model: mlp + len_curve_perc: 0.5 noise_removal: sampler: default @@ -60,6 +116,51 @@ datasets: threshold_y: threshold: 89 + click: + openml_id: 1216 + + covertype: + openml_id: 1596 + + phoneme: + openml_id: 1489 + + fmnist_binary: + openml_id: 40996 + filters: + binarization: + label_zero: '0' + label_one: '1' + preprocessor: + principal_resnet_components: + n_components: 32 + grayscale: true + seed: 101 + + cifar10: + openml_id: 40927 + filters: + binarization: + label_zero: '1' + label_one: '9' + preprocessor: + principal_resnet_components: + n_components: 32 + grayscale: false + seed: 102 + + mnist_binary: + openml_id: 554 + filters: + binarization: + label_zero: '1' + label_one: '7' + preprocessor: + principal_resnet_components: + n_components: 32 + grayscale: true + seed: 103 + mnist_multi: openml_id: 554 preprocessor: @@ -73,11 +174,34 @@ models: model: logistic_regression solver: liblinear + logistic_regression_balanced: + model: logistic_regression + solver: liblinear + class_weight: balanced + + gradient_boosting_classifier: + model: gradient_boosting_classifier + n_estimators: 40 + min_samples_split: 6 + max_depth: 2 + knn: model: knn n_neighbors: 5 weights: uniform + svm: + model: svm + kernel: rbf + + svm_balanced: + model: svm + kernel: rbf + class_weight: balanced + + mlp: + model: mlp + valuation_methods: random: algorithm: random @@ -86,3 +210,53 @@ valuation_methods: algorithm: loo progress: true cache_group: acc + + classwise_shapley: + algorithm: classwise_shapley + cache_group: disc_acc + normalize_values: true + n_resample_complement_sets: 1 + n_updates: 500 + rtol: 1e-4 + progress: true + use_default_scorer_value: false + + + beta_shapley: + algorithm: beta_shapley + alpha: 16.0 + beta: 1.0 + n_updates: 500 + progress: true + cache_group: acc + + + banzhaf_shapley: + algorithm: banzhaf_shapley + n_updates: 500 + progress: true + cache_group: acc + + + tmc_shapley: + algorithm: tmc_shapley + rtol: 1e-4 + n_updates: 500 + progress: true + cache_group: acc + + + owen_sampling_shapley: + algorithm: owen_sampling_shapley + n_updates: 10 + max_q: 50 + progress: false + cache_group: acc + + + least_core: + algorithm: least_core + n_updates: 5000 + progress: false + cache_group: acc + diff --git a/poetry.lock b/poetry.lock index a5587d21e..ab1177dd0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2867,6 +2867,147 @@ files = [ {file = "numpy-1.26.3.tar.gz", hash = "sha256:697df43e2b6310ecc9d95f05d5ef20eacc09c7c4ecc9da3f235d39e71b7da1e4"}, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.1.3.1" +description = "CUBLAS native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +description = "CUDA profiling tools runtime libs." +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" +description = "NVRTC native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +description = "CUDA Runtime native Libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" +description = "cuDNN runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +description = "CUFFT native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.2.106" +description = "CURAND native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +description = "CUDA solver native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +description = "CUSPARSE native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.18.1" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.3.101" +description = "Nvidia JIT LTO Library" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] + [[package]] name = "oauthlib" version = "3.2.2" @@ -4678,77 +4819,91 @@ files = [ [[package]] name = "torch" -version = "2.0.1" +version = "2.1.2" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8ced00b3ba471856b993822508f77c98f48a458623596a4c43136158781e306a"}, - {file = "torch-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:359bfaad94d1cda02ab775dc1cc386d585712329bb47b8741607ef6ef4950747"}, - {file = "torch-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:7c84e44d9002182edd859f3400deaa7410f5ec948a519cc7ef512c2f9b34d2c4"}, - {file = "torch-2.0.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:567f84d657edc5582d716900543e6e62353dbe275e61cdc36eda4929e46df9e7"}, - {file = "torch-2.0.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:787b5a78aa7917465e9b96399b883920c88a08f4eb63b5a5d2d1a16e27d2f89b"}, - {file = "torch-2.0.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e617b1d0abaf6ced02dbb9486803abfef0d581609b09641b34fa315c9c40766d"}, - {file = "torch-2.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b6019b1de4978e96daa21d6a3ebb41e88a0b474898fe251fd96189587408873e"}, - {file = "torch-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:dbd68cbd1cd9da32fe5d294dd3411509b3d841baecb780b38b3b7b06c7754434"}, - {file = "torch-2.0.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:ef654427d91600129864644e35deea761fb1fe131710180b952a6f2e2207075e"}, - {file = "torch-2.0.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:25aa43ca80dcdf32f13da04c503ec7afdf8e77e3a0183dd85cd3e53b2842e527"}, - {file = "torch-2.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5ef3ea3d25441d3957348f7e99c7824d33798258a2bf5f0f0277cbcadad2e20d"}, - {file = "torch-2.0.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:0882243755ff28895e8e6dc6bc26ebcf5aa0911ed81b2a12f241fc4b09075b13"}, - {file = "torch-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:f66aa6b9580a22b04d0af54fcd042f52406a8479e2b6a550e3d9f95963e168c8"}, - {file = "torch-2.0.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:1adb60d369f2650cac8e9a95b1d5758e25d526a34808f7448d0bd599e4ae9072"}, - {file = "torch-2.0.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:1bcffc16b89e296826b33b98db5166f990e3b72654a2b90673e817b16c50e32b"}, - {file = "torch-2.0.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:e10e1597f2175365285db1b24019eb6f04d53dcd626c735fc502f1e8b6be9875"}, - {file = "torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:423e0ae257b756bb45a4b49072046772d1ad0c592265c5080070e0767da4e490"}, - {file = "torch-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:8742bdc62946c93f75ff92da00e3803216c6cce9b132fbca69664ca38cfb3e18"}, - {file = "torch-2.0.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:c62df99352bd6ee5a5a8d1832452110435d178b5164de450831a3a8cc14dc680"}, - {file = "torch-2.0.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:671a2565e3f63b8fe8e42ae3e36ad249fe5e567435ea27b94edaa672a7d0c416"}, + {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"}, + {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"}, + {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"}, + {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"}, + {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"}, + {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"}, + {file = "torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"}, + {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"}, + {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"}, + {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"}, + {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"}, + {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"}, + {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"}, + {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"}, + {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"}, + {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"}, + {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"}, + {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"}, + {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"}, + {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"}, ] [package.dependencies] filelock = "*" +fsspec = "*" jinja2 = "*" networkx = "*" +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" +triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} typing-extensions = "*" [package.extras] +dynamo = ["jinja2"] opt-einsum = ["opt-einsum (>=3.3)"] [[package]] name = "torchvision" -version = "0.15.2" +version = "0.16.2" description = "image and video datasets and models for torch deep learning" optional = false python-versions = ">=3.8" files = [ - {file = "torchvision-0.15.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7754088774e810c5672b142a45dcf20b1bd986a5a7da90f8660c43dc43fb850c"}, - {file = "torchvision-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37eb138e13f6212537a3009ac218695483a635c404b6cc1d8e0d0d978026a86d"}, - {file = "torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:54143f7cc0797d199b98a53b7d21c3f97615762d4dd17ad45a41c7e80d880e73"}, - {file = "torchvision-0.15.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:1eefebf5fbd01a95fe8f003d623d941601c94b5cec547b420da89cb369d9cf96"}, - {file = "torchvision-0.15.2-cp310-cp310-win_amd64.whl", hash = "sha256:96fae30c5ca8423f4b9790df0f0d929748e32718d88709b7b567d2f630c042e3"}, - {file = "torchvision-0.15.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5f35f6bd5bcc4568e6522e4137fa60fcc72f4fa3e615321c26cd87e855acd398"}, - {file = "torchvision-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:757505a0ab2be7096cb9d2bf4723202c971cceddb72c7952a7e877f773de0f8a"}, - {file = "torchvision-0.15.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:012ad25cfd9019ff9b0714a168727e3845029be1af82296ff1e1482931fa4b80"}, - {file = "torchvision-0.15.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b02a7ffeaa61448737f39a4210b8ee60234bda0515a0c0d8562f884454105b0f"}, - {file = "torchvision-0.15.2-cp311-cp311-win_amd64.whl", hash = "sha256:10be76ceded48329d0a0355ac33da131ee3993ff6c125e4a02ab34b5baa2472c"}, - {file = "torchvision-0.15.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8f12415b686dba884fb086f53ac803f692be5a5cdd8a758f50812b30fffea2e4"}, - {file = "torchvision-0.15.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:31211c01f8b8ec33b8a638327b5463212e79a03e43c895f88049f97af1bd12fd"}, - {file = "torchvision-0.15.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c55f9889e436f14b4f84a9c00ebad0d31f5b4626f10cf8018e6c676f92a6d199"}, - {file = "torchvision-0.15.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9a192f2aa979438f23c20e883980b23d13268ab9f819498774a6d2eb021802c2"}, - {file = "torchvision-0.15.2-cp38-cp38-win_amd64.whl", hash = "sha256:c07071bc8d02aa8fcdfe139ab6a1ef57d3b64c9e30e84d12d45c9f4d89fb6536"}, - {file = "torchvision-0.15.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4790260fcf478a41c7ecc60a6d5200a88159fdd8d756e9f29f0f8c59c4a67a68"}, - {file = "torchvision-0.15.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:987ab62225b4151a11e53fd06150c5258ced24ac9d7c547e0e4ab6fbca92a5ce"}, - {file = "torchvision-0.15.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:63df26673e66cba3f17e07c327a8cafa3cce98265dbc3da329f1951d45966838"}, - {file = "torchvision-0.15.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b85f98d4cc2f72452f6792ab4463a3541bc5678a8cdd3da0e139ba2fe8b56d42"}, - {file = "torchvision-0.15.2-cp39-cp39-win_amd64.whl", hash = "sha256:07c462524cc1bba5190c16a9d47eac1fca024d60595a310f23c00b4ffff18b30"}, + {file = "torchvision-0.16.2-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:bc86f2800cb2c0c1a09c581409cdd6bff66e62f103dc83fc63f73346264c3756"}, + {file = "torchvision-0.16.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b024bd412df6d3a007dcebf311a894eb3c5c21e1af80d12be382bbcb097a7c3a"}, + {file = "torchvision-0.16.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:e89f10f3c8351972b6e3fda95bc3e479ea8dbfc9dfcfd2c32902dbad4ba5cfc5"}, + {file = "torchvision-0.16.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:96c7583700112a410bdc4e1e4f118c429dab49c29c9a31a2cc3579bc9b08b19d"}, + {file = "torchvision-0.16.2-cp310-cp310-win_amd64.whl", hash = "sha256:9f4032ebb3277fb07ff6a9b818d50a547fb8fcd89d958cfd9e773322454bb688"}, + {file = "torchvision-0.16.2-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:67b1aaf8b8cb02ce75dd445f291a27c8036a502f8c0aa76e28c37a0faac2e153"}, + {file = "torchvision-0.16.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bef30d03e1d1c629761f4dca51d3b7d8a0dc0acce6f4068ab2a1634e8e7b64e0"}, + {file = "torchvision-0.16.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e59cc7b2bd1ab5c0ce4ae382e4e37be8f1c174e8b5de2f6a23c170de9ae28495"}, + {file = "torchvision-0.16.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e130b08cc9b3cc73a6c59d6edf032394a322f9579bfd21d14bc2e1d0999aa758"}, + {file = "torchvision-0.16.2-cp311-cp311-win_amd64.whl", hash = "sha256:8692ab1e48807e9604046a6f4beeb67b523294cee1b00828654bb0df2cfce2b2"}, + {file = "torchvision-0.16.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:b82732dcf876a37c852772342aa6ee3480c03bb3e2a802ae109fc5f7e28d26e9"}, + {file = "torchvision-0.16.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4b065143d1a720fe8a9077fd4be35d491f98819ec80b3dbbc3ec64d0b707a906"}, + {file = "torchvision-0.16.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:bc5f274e4ecd1b86062063cdf4fd385a1d39d147a3a2685fbbde9ff08bb720b8"}, + {file = "torchvision-0.16.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:335959c43b371c0474af34c1ef2a52efdc7603c45700d29e4475eeb02984170c"}, + {file = "torchvision-0.16.2-cp38-cp38-win_amd64.whl", hash = "sha256:7fd22d86e08eba321af70cad291020c2cdeac069b00ce88b923ca52e06174769"}, + {file = "torchvision-0.16.2-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:56115268b37f0b75364e3654e47ad9abc66ac34c1f9e5e3dfa89a22d6a40017a"}, + {file = "torchvision-0.16.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:82805f8445b094f9d1e770390ee6cc86855e89955e08ce34af2e2274fc0e5c45"}, + {file = "torchvision-0.16.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3f4bd5fcbc361476e2e78016636ac7d5509e59d9962521f06eb98e6803898182"}, + {file = "torchvision-0.16.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8199acdf8ab066a28b84a5b6f4d97b58976d9e164b1acc3a9d14fccfaf74bb3a"}, + {file = "torchvision-0.16.2-cp39-cp39-win_amd64.whl", hash = "sha256:41dd4fa9f176d563fe9f1b9adef3b7e582cdfb60ce8c9bc51b094a025be687c9"}, ] [package.dependencies] numpy = "*" pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" requests = "*" -torch = "2.0.1" +torch = "2.1.2" [package.extras] scipy = ["scipy"] @@ -4808,6 +4963,31 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "triton" +version = "2.1.0" +description = "A language and compiler for custom Deep Learning operations" +optional = false +python-versions = "*" +files = [ + {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, + {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, + {file = "triton-2.1.0-0-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ae4bb8a91de790e1866405211c4d618379781188f40d5c4c399766914e84cd94"}, + {file = "triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39f6fb6bdccb3e98f3152e3fbea724f1aeae7d749412bbb1fa9c441d474eba26"}, + {file = "triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21544e522c02005a626c8ad63d39bdff2f31d41069592919ef281e964ed26446"}, + {file = "triton-2.1.0-0-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:143582ca31dd89cd982bd3bf53666bab1c7527d41e185f9e3d8a3051ce1b663b"}, + {file = "triton-2.1.0-0-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82fc5aeeedf6e36be4e4530cbdcba81a09d65c18e02f52dc298696d45721f3bd"}, + {file = "triton-2.1.0-0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:81a96d110a738ff63339fc892ded095b31bd0d205e3aace262af8400d40b6fa8"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.18)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + [[package]] name = "typer" version = "0.9.0" @@ -5129,4 +5309,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "~3.10" -content-hash = "ad115920c240c9e66e499d554dff120dc9872810d5581cef745b590ca2dad778" +content-hash = "18ef6e5a8d5464091ddce33bfa5fb78987db53d7506037ccc92170d70aac6ddf" diff --git a/pyproject.toml b/pyproject.toml index e048743c1..e65232e87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,8 +16,8 @@ matplotlib = "^3.6.3" seaborn = "^0.12.2" dvc = "^3.15.2" tqdm = "^4.64.1" -torch = "^2.0.0" -torchvision = "^0.15.1" +torch = "2.1.2" +torchvision = "0.16.2" openml = "^0.13.0" click = "^8.1.3" mlflow = "^2.9.2" diff --git a/scripts/determine_in_out_of_cls_marginal_accuracies.py b/scripts/determine_in_out_of_cls_marginal_accuracies.py index 5a58660aa..c4db3cabd 100644 --- a/scripts/determine_in_out_of_cls_marginal_accuracies.py +++ b/scripts/determine_in_out_of_cls_marginal_accuracies.py @@ -37,7 +37,7 @@ @click.command() @click.option("--experiment-name", type=str, required=True) @click.option("--model-name", type=str, required=True) -@click.option("--valuation-method-name", type=str, default="loo") +@click.option("--valuation-method-name", type=str, default="tmc_shapley") @click.option("--max-plotting-percentage", type=float, default=1e-4) def determine_in_cls_out_of_cls_marginal_accuracies( experiment_name: str, @@ -62,7 +62,7 @@ def determine_in_cls_out_of_cls_marginal_accuracies( def _determine_in_cls_out_of_cls_marginal_accuracies( experiment_name: str, model_name: str = "logistic_regression", - valuation_method_name: str = "loo", + valuation_method_name: str = "tmc_shapley", max_plotting_percentage: float = 1e-4, ): output_dir = Accessor.INFO_PATH / experiment_name diff --git a/scripts/render_plots.py b/scripts/render_plots.py index 67c9d85c8..a09061c9a 100644 --- a/scripts/render_plots.py +++ b/scripts/render_plots.py @@ -104,7 +104,7 @@ def _render_plots(experiment_name: str, model_name: str): ) for method_name in method_names: logger.info(f"Plot histogram for method {method_name} values.") - with plot_histogram(valuation_results, [method_name, "tmc_shapley"]) as fig: + with plot_histogram(valuation_results, [method_name]) as fig: log_figure( fig, output_folder, f"density.{method_name}.svg", "densities" ) diff --git a/src/re_classwise_shapley/metric.py b/src/re_classwise_shapley/metric.py index 21d05cc0f..c2d87767d 100644 --- a/src/re_classwise_shapley/metric.py +++ b/src/re_classwise_shapley/metric.py @@ -10,7 +10,7 @@ init_executor, init_parallel_backend, ) -from pydvl.utils import Dataset, Scorer, Seed, Utility +from pydvl.utils import CacheBackend, Dataset, Scorer, Seed, Utility from pydvl.value.result import ValuationResult from sklearn.metrics import auc from tqdm import tqdm @@ -57,7 +57,8 @@ def metric_weighted_metric_drop( n_jobs: int = 1, config: ParallelConfig = ParallelConfig(), progress: bool = False, - seed: Optional[Seed] = None, + seed: Seed | None = None, + cache: CacheBackend | None = None, ) -> Tuple[float, pd.Series]: r""" Calculates the weighted reciprocal difference average of a valuation function. The @@ -89,7 +90,7 @@ def metric_weighted_metric_drop( model=model, scorer=Scorer(metric, default=np.nan), catch_errors=True, - enable_cache=True, + cache_backend=cache, ) curve = _curve_score_over_point_removal_or_addition( u_eval, diff --git a/src/re_classwise_shapley/model.py b/src/re_classwise_shapley/model.py index 5f03a51f5..463de9f36 100644 --- a/src/re_classwise_shapley/model.py +++ b/src/re_classwise_shapley/model.py @@ -43,7 +43,7 @@ def instantiate_model( "mlp_classifier": MLPClassifier, } model_class = model_dict[model_idx] - model = maybe_add_argument(model_class, random_state)( + model = maybe_add_argument(model_class, "random_state")( **model_kwargs, random_state=random_state ) pipeline = make_pipeline(StandardScaler(), model)