From 75eb3b38352943a75dc87fef16f7f50cdd1cc11f Mon Sep 17 00:00:00 2001 From: Simon Rey Date: Thu, 5 Oct 2023 12:09:50 +0200 Subject: [PATCH] Fixe for approvalprofile generator --- pabutools/election/profile/approvalprofile.py | 5 ++-- .../satisfaction/satisfactionmeasure.py | 23 +++++++++++++++++++ .../satisfaction/satisfactionprofile.py | 16 ++++++++++++- pabutools/rules/greedywelfare.py | 9 ++++++-- pabutools/rules/maxwelfare.py | 1 - tests/test_profile.py | 6 ++--- tests/test_rule.py | 10 ++++---- 7 files changed, 55 insertions(+), 15 deletions(-) diff --git a/pabutools/election/profile/approvalprofile.py b/pabutools/election/profile/approvalprofile.py index 56c39508..712f6214 100644 --- a/pabutools/election/profile/approvalprofile.py +++ b/pabutools/election/profile/approvalprofile.py @@ -316,7 +316,7 @@ def get_random_approval_profile(instance: Instance, num_agents: int) -> Approval def get_all_approval_profiles( instance: Instance, num_agents: int -) -> Generator[Iterable[Project]]: +) -> Generator[ApprovalProfile]: """ Returns a generator over all the possible profile for a given instance of a given length. @@ -332,7 +332,8 @@ def get_all_approval_profiles( Generator[Iterable[:py:class:`~pabutools.election.instance.Project`]] Generator over subsets of projects. """ - return product(powerset(instance), repeat=num_agents) + for p in product(powerset(instance), repeat=num_agents): + yield ApprovalProfile([ApprovalBallot(b) for b in p], instance=instance) class ApprovalMultiProfile(MultiProfile, AbstractApprovalProfile): diff --git a/pabutools/election/satisfaction/satisfactionmeasure.py b/pabutools/election/satisfaction/satisfactionmeasure.py index db21040e..bf4e3f3b 100644 --- a/pabutools/election/satisfaction/satisfactionmeasure.py +++ b/pabutools/election/satisfaction/satisfactionmeasure.py @@ -140,3 +140,26 @@ def total_satisfaction(self, projects: Iterable[Project]) -> Number: for sat in self: res += sat.sat(projects) * self.multiplicity(sat) return res + + @abstractmethod + def remove_satisfied( + self, sat_bound: dict[AbstractBallot, Number], projects: Iterable[Project] + ) -> GroupSatisfactionMeasure: + """ + Returns a new satisfaction profile excluding the satisfaction measurs corresponding to satisfied voters, i.e., + who have met or exceeded their satisfaction bound for a given collection of projects. + + Parameters + ---------- + sat_bound : dict[str, Number] + A dictionary of ballot names to numbers, specifying for each ballot the satisfaction bound above which + the voter is considered satisfied. Note that the keys are ballot names, and that nothing ensures ballot + names to be unique, so be careful here. + projects : Iterable[:py:class:`~pabutools.election.instance.Project`] + The collection of projects. + + Returns + ------- + :py:class:`~pabutools.election.satisfaction.satisfactionmeasure.GroupSatisfactionMeasure` + The new satisfaction profile. + """ diff --git a/pabutools/election/satisfaction/satisfactionprofile.py b/pabutools/election/satisfaction/satisfactionprofile.py index 9995a4f1..6b17c4c2 100644 --- a/pabutools/election/satisfaction/satisfactionprofile.py +++ b/pabutools/election/satisfaction/satisfactionprofile.py @@ -5,12 +5,14 @@ from collections import Counter from collections.abc import Iterable +from numbers import Number from pabutools.election.satisfaction.satisfactionmeasure import ( SatisfactionMeasure, GroupSatisfactionMeasure, ) -from pabutools.election.instance import Instance +from pabutools.election.instance import Instance, Project +from pabutools.election.ballot.ballot import AbstractBallot from typing import TYPE_CHECKING @@ -116,6 +118,12 @@ def multiplicity(self, sat: SatisfactionMeasure) -> int: """ return 1 + def remove_satisfied(self, sat_bound: dict[str, Number], + projects: Iterable[Project]) -> SatisfactionProfile: + res = SatisfactionProfile((s for s in self if s.sat(projects) < sat_bound[s.ballot.name]), instance=self.instance) + res.sat_class = self.sat_class + return res + @classmethod def _wrap_methods(cls, names): def wrap_method_closure(name): @@ -296,6 +304,12 @@ def extend_from_multiprofile( def multiplicity(self, sat: SatisfactionMeasure) -> int: return self[sat] + def remove_satisfied(self, sat_bound: dict[AbstractBallot, Number], + projects: Iterable[Project]) -> SatisfactionMultiProfile: + res = SatisfactionMultiProfile({s: m for s, m in self.items() if s.sat(projects) < sat_bound[s.ballot.name]}, instance=self.instance) + res.sat_class = self.sat_class + return res + @classmethod def _wrap_methods(cls, names): def wrap_method_closure(name): diff --git a/pabutools/rules/greedywelfare.py b/pabutools/rules/greedywelfare.py index 82f54f5a..3f2bcc0e 100644 --- a/pabutools/rules/greedywelfare.py +++ b/pabutools/rules/greedywelfare.py @@ -4,7 +4,9 @@ from copy import copy from collections.abc import Iterable from math import inf +from numbers import Number +from pabutools.election import AbstractBallot from pabutools.election.profile import AbstractProfile from pabutools.fractions import frac @@ -24,6 +26,7 @@ def greedy_utilitarian_scheme( budget_allocation: Iterable[Project], tie_breaking: TieBreakingRule, resoluteness: bool = True, + sat_bounds: dict[AbstractBallot, Number] = None, ) -> Iterable[Project] | Iterable[Iterable[Project]]: """ The inner algorithm for the greedy rule. It selects projects in rounds, each time selecting a project that @@ -71,7 +74,8 @@ def aux(inst, prof, sats, allocs, alloc, tie, resolute): new_alloc = copy(alloc) + [project] if project.cost > 0: total_marginal_score = frac( - sats.total_satisfaction(new_alloc) - sats.total_satisfaction(alloc), + sats.total_satisfaction(new_alloc) + - sats.total_satisfaction(alloc), project.cost, ) else: @@ -164,6 +168,7 @@ def satisfaction_density(proj): return frac(total_sat, proj.cost) return inf return 0 + # We sort based on a tuple to ensure ties are broken as intended ordered_projects = sorted( projects, key=lambda p: (-satisfaction_density(p), projects.index(p)) @@ -233,7 +238,7 @@ def greedy_utilitarian_welfare( budget_allocation = [] if sat_class is None: if sat_profile is None: - raise ValueError("Satisfaction and sat_profile cannot both be None.") + raise ValueError("sat_class and sat_profile cannot both be None.") else: if sat_profile is None: sat_profile = profile.as_sat_profile(sat_class) diff --git a/pabutools/rules/maxwelfare.py b/pabutools/rules/maxwelfare.py index 65c93b64..d4f899ca 100644 --- a/pabutools/rules/maxwelfare.py +++ b/pabutools/rules/maxwelfare.py @@ -150,7 +150,6 @@ def max_additive_utilitarian_welfare( else: if sat_profile is None: sat_profile = profile.as_sat_profile(sat_class=sat_class) - return max_additive_utilitarian_welfare_scheme( instance, sat_profile, budget_allocation, resoluteness=resoluteness ) diff --git a/tests/test_profile.py b/tests/test_profile.py index 765e5bbe..c4a80497 100644 --- a/tests/test_profile.py +++ b/tests/test_profile.py @@ -143,9 +143,9 @@ def test_approval_profile(self): new_inst = Instance( [Project("p1", 1), Project("p2", 1), Project("p3", 1)], budget_limit=3 ) - assert len(set(get_all_approval_profiles(new_inst, 1))) == 8 - assert len(set(get_all_approval_profiles(new_inst, 2))) == 8 * 8 - assert len(set(get_all_approval_profiles(new_inst, 3))) == 8 * 8 * 8 + assert len(list(get_all_approval_profiles(new_inst, 1))) == 8 + assert len(list(get_all_approval_profiles(new_inst, 2))) == 8 * 8 + assert len(list(get_all_approval_profiles(new_inst, 3))) == 8 * 8 * 8 def test_app_multiprofile(self): projects = [Project("p" + str(i), cost=2) for i in range(10)] diff --git a/tests/test_rule.py b/tests/test_rule.py index 0728b04b..752bb6fa 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -154,12 +154,10 @@ def dummy_elections(): test_election.irr_results_sat[max_additive_utilitarian_welfare][Cost_Sat] = sorted( [[p[0], p[2]], [p[2]]] ) - test_election.irr_results_sat[max_additive_utilitarian_welfare][Cardinality_Sat] = sorted( - [[p[0], p[2]]] - ) - test_election.irr_results_sat[method_of_equal_shares][Cost_Sat] = sorted( - [[]] - ) + test_election.irr_results_sat[max_additive_utilitarian_welfare][ + Cardinality_Sat + ] = sorted([[p[0], p[2]]]) + test_election.irr_results_sat[method_of_equal_shares][Cost_Sat] = sorted([[]]) test_election.irr_results_sat[method_of_equal_shares][Cardinality_Sat] = sorted( [[p[0]]] )