diff --git a/payload/msda.py b/payload/msda.py index 7b122ec..13750aa 100755 --- a/payload/msda.py +++ b/payload/msda.py @@ -36,7 +36,7 @@ __author__ = 'David G. Rosenberg' __copyright__ = 'Copyright (c), Mac Set Default Apps' __license__ = 'MIT' -__version__ = '1.1.0' +__version__ = '1.1.1' __email__ = 'dgrosenberg@icloud.com' @@ -117,7 +117,12 @@ def get_current_username(): def gather_user_ls_paths(): gathered_users = os.listdir(USER_HOMES_LOCATION) - ls_paths = [ create_user_ls_path(u) for u in gathered_users ] + ls_paths = [] + for user in gathered_users: + user_ls_path = create_user_ls_path(user) + if os.path.exists(user_ls_path): + ls_paths.append(user_ls_path) + return ls_paths diff --git a/tests/tests.py b/tests/tests.py index abff56d..09b2bba 100755 --- a/tests/tests.py +++ b/tests/tests.py @@ -51,16 +51,16 @@ def test_seed_plist_copies_plist_into_tmp(self): self.assertTrue(os.path.exists(tmp_path)) -class TestFunctions(TestCase): - - def setUp(self): - self.tmp = tempfile.mkdtemp(prefix=msda.TMP_PREFIX) - - def tearDown(self): - shutil.rmtree(self.tmp) +class TestFunctions(LaunchServicesTestCase): def test_gather_user_ls_paths(self): fake_user_homes = create_user_homes(3, self.tmp) + for fake_user_home in fake_user_homes: + self.seed_plist( + SIMPLE_BINARY_PLIST, + os.path.join(fake_user_home, msda.PLIST_RELATIVE_LOCATION), + msda.PLIST_NAME, + ) with mock.patch('msda.USER_HOMES_LOCATION', self.tmp): gathered_ls_paths = msda.gather_user_ls_paths() @@ -537,6 +537,46 @@ def test_set_handlers_for_all_existing_users_and_user_template(self, self.assertIn(handler, user_ls.handlers) self.assertIn(handler.app_id, user_ls.app_ids) + def test_set_handlers_for_all_existing_valid_users(self,): + fake_user_home_location = os.path.join(self.tmp, 'Users') + fake_user_homes = create_user_homes(randint(3, 5), fake_user_home_location) + num_invalid_users = randint(1, 2) + handlers = LSHandlerFactory.build_batch(randint(4, 6)) + arguments = ['set', '-feu', handlers[0].app_id] + + for handler in handlers: + if '.' in handler.uti: + arguments.extend(['-u', handler.uti, handler.role]) + else: + arguments.extend(['-p', handler.uti]) + + for user_home in fake_user_homes[:-num_invalid_users]: + user_ls_path = self.seed_plist( + SIMPLE_BINARY_PLIST, + os.path.join(user_home, msda.PLIST_RELATIVE_LOCATION), + msda.PLIST_NAME, + ) + user_ls = msda.LaunchServices(user_ls_path) + for handler in handlers: + self.assertNotIn(handler, user_ls.handlers) + self.assertNotIn(handler.app_id, user_ls.app_ids) + + with mock.patch('msda.USER_HOMES_LOCATION', fake_user_home_location): + msda.main(arguments) + + for user_home in fake_user_homes[:-num_invalid_users]: + user_ls.read() + for handler in handlers: + self.assertIn(handler, user_ls.handlers) + self.assertIn(handler.app_id, user_ls.app_ids) + + for user_home in fake_user_homes[-num_invalid_users:]: + self.assertFalse(os.path.exists(os.path.join( + user_home, + msda.PLIST_RELATIVE_LOCATION, + msda.PLIST_NAME, + ))) + if __name__ == '__main__': unittest.main()