-
Notifications
You must be signed in to change notification settings - Fork 160
/
patch_module.py
61 lines (49 loc) · 1.53 KB
/
patch_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from collections import defaultdict
import importlib
import sys
class ModulePatchRegister:
    """Class-level registry mapping module names to patch callables.

    Patches are stored per module name in registration order and are meant
    to be applied (by the import hook) once the named module is imported.
    """

    # module name -> list of patch functions, in registration order
    register = defaultdict(list)

    @classmethod
    def register_patch(cls, mod_name, func):
        """Append *func* to the patches to apply to module *mod_name*.

        :param mod_name: fully-qualified module name.
        :param func: callable taking the imported module as its only argument.
        """
        cls.register[mod_name].append(func)

    @classmethod
    def is_module_patched(cls, name):
        """Return True if at least one patch is registered for *name*."""
        # .get() avoids defaultdict's auto-insertion; an empty list means
        # "no patches", not "patched".
        return bool(cls.register.get(name))

    @classmethod
    def get_module_patches(cls, name):
        """Return a list of the patches registered for *name*.

        The result is a snapshot copy (empty if nothing is registered).
        Uses ``.get`` rather than indexing so that querying an unregistered
        name does not insert an empty entry into the ``defaultdict``, which
        would otherwise make the name look patched afterwards.
        """
        return list(cls.register.get(name, ()))
class PatchMetaPathFinder:
    """Meta-path finder that routes imports of patched modules through
    ``PatchModuleLoader``.

    Only module names with patches registered in ``ModulePatchRegister`` are
    intercepted; every other import is declined (``None``) and left to the
    remaining finders on ``sys.meta_path``.
    """

    def __init__(self):
        # Names whose import is currently being delegated to the normal
        # import machinery. The finder must step aside for these, otherwise
        # the loader's nested import would recurse back into this finder.
        self.skip = set()

    def find_module(self, name, path=None):
        """Return a loader for *name* if it has registered patches.

        :param name: fully-qualified name of the module being imported.
        :param path: parent package ``__path__`` for submodules (unused).
        :returns: a ``PatchModuleLoader``, or ``None`` to decline.
        """
        # NOTE(review): this is the legacy PEP 302 finder protocol; newer
        # Python 3 releases consult only ``find_spec`` — confirm the target
        # interpreter before relying on this hook.
        if name in self.skip:
            return None
        if not ModulePatchRegister.is_module_patched(name):
            return None
        # Mark the import as in flight; the loader removes the mark again.
        self.skip.add(name)
        return PatchModuleLoader(self)
class PatchModuleLoader:
    """Loader that performs the normal import, then applies registered patches."""

    def __init__(self, finder):
        # The finder that created this loader; its ``skip`` set holds the
        # name currently being loaded.
        self._finder = finder

    def load_module(self, name):
        """Import *name* normally, run its registered patches, return the module.

        :param name: fully-qualified module name.
        :returns: the imported (and patched) module object.
        :raises ImportError: if the underlying import fails. Exceptions raised
            by patch functions also propagate.
        """
        try:
            # *name* is in the finder's skip set, so this nested import is
            # resolved by the other finders on sys.meta_path.
            mod = importlib.import_module(name)
            if ModulePatchRegister.is_module_patched(name):
                for patch in ModulePatchRegister.get_module_patches(name):
                    patch(mod)
        finally:
            # Always clear the in-flight mark, even if the import or a patch
            # fails, so a later import attempt is intercepted again. discard()
            # cannot raise and therefore never masks the original exception.
            self._finder.skip.discard(name)
        return mod
# Install the patching finder at the front of sys.meta_path so it is consulted
# before every finder already registered there.
sys.meta_path[0:0] = [PatchMetaPathFinder()]
def when_importing(modname):
    """Decorator factory: use the decorated function as a patch for *modname*.

    If *modname* is already in ``sys.modules``, the function is called right
    away with that module. Otherwise it is registered with
    ``ModulePatchRegister`` and applied later, when the module is imported.

    :param modname: fully-qualified name of the module to patch.
    :returns: a decorator that returns the decorated function unchanged.
    """
    def decorated(func):
        if modname in sys.modules:
            func(sys.modules[modname])
        else:
            ModulePatchRegister.register_patch(modname, func)
        # Return func so the decorated name is not rebound to None.
        return func
    return decorated
# Demonstration patches (for illustration purposes only)
@when_importing("threading")
def warn(mod):
    """Demo patch: print a warning when the ``threading`` module is patched.

    :param mod: the ``threading`` module object (unused).
    """
    # The call form of print with a single argument behaves the same on
    # Python 2 and Python 3; the bare print statement is a SyntaxError on 3.
    print("Warning, you are entering dangerous territory!")
@when_importing("math")
def new_math(mod):
    """Demo patch: set an ``abs`` helper on the ``math`` module.

    :param mod: the ``math`` module object; mutated in place.
    """
    def new_abs(num):
        # Absolute value of num. The condition was previously inverted
        # (``num < 0``), which returned -|num| instead.
        return num if num >= 0 else -num
    mod.abs = new_abs