LwrlO`B4>MP&I{>xRf{gQuO08CX2;0EiHHM!rGzp;9*i)wKFuiU$|83KXb
zs6tCC(u+M6F$T+qqb*$?!1q#<^JVpbqFJ&4VbqxJ0{;MuhxOfdxG6Dr?TbFJFN38N
z5{D5bYd_L~y#ggq8YZmEX7R^|cQBJC8CH
z4&ZKT@CP)&>op2SVk))s0fvkbq=mw=hxd69FU7ht!DnwU78WN>>ZSjkUrCZS!!n+
zGm}Opz6y6xm#-4l&-`}ZmD*l{jm0WvqMe)9;Pzr@QT@3`uW-mi{_NfdK9ggQw~7_{
zJsGq00liZzF=KQ%I0G5#5L**jC!UtzloIb6Umy7HO;mZuDb?!_wCJGdkPnq;|0AXs
z2*Z@qWz^hTm9HLWmRY{w&Yj_whXv4`2At<#?}ZXuMc;J2-s+i`?a8>B{~L*owB|00Rj2S8tD=3-w$ZNY<~>EMe5jJT8vlA`#K(KRU{UXM+5H0_b`F%@`c%7cysOS!nx>s
z&C6fi0-bPO?^OpL+~5{f)s%ftHCo&F}ok==o8*6A9^2+{h)kZ|TCbe>ID<
zvRwsYS$Ar>gV{exlt#Cp-aK7Z9yQ{AkbW^T#LYQE+Sklncs4j8nDyHa;xn(KwlhMl
zRYhU@U7pv$+8>S^Dyna`In{vL{lkbG>C`n71{<s-AdAv0xkQ+K
zOZBZOUW2mVB7t{6joNc5`S
zmtAmEM?fzBpx;#4QMuN?hpSnieSbAC^nF>ph`q0f5lPl}8HSPrvwJ&mnom>*E$#r@
z=D#iP&iyymNd#Zj{v;<)O;2f#wSptv8UL=sxtr133
zgqikJ-)ChwujcPWG<(`cY#{=0zk>bOtX^z0H=)_MMSurQ<=8Gjc-cM5N`O@_NHVv7eK{`dvVV@ptJV`xnlGpmaS(NwWIIhzmnq>7{PsdE}$}k%LdWW
zC|U$6cbg%>%hy~OhwbvDyZcnFzZSd0b1+2=Rj^2)eng8BGne}6>4|(>#hL9wFIn)P
z|28Hf(h4m+c>Mv4aN(WuOaa76GP)Oun>f-KZHzyJd$?4$Kf)O!Vx^qes^?9JbH~@+
z?9rz;DIq7x8?(H%gPCd5K@z;9O~dZqv|M-a)q@0$0t413=6uBB&^%CMLaCNf(^>Gq
zS09w{oYx7iGC}|!Wgf-h$J4B@N{|68WiBNuD-2e#+!*Wr1+{{Iow!|>{^hax{jojZ
zTA_i5z$vc%NlDx;B(2My;3vwero5)ypmBkCDpjS>PW2*?-r5Q&OMP|giJGEvhx1GL
zm-DHbzW1lZf=dG)OC*lyi8Y|OdIlSMn@6lvTlz>RB+LTG9W$E8sRazyOz6)X6tGnr
z>z|1#IuPpjvG;f)_9K6jWqVb7SX2fKItLuRTLUAH2fZa=;3p3R=YG_D9=|bv2qYdY
zvK{=leL`!Ae9Zlt{a}}1=W)Z0`T79Z(o5&4shjGMNTK)V9ifvtwE49kN`IIM2Q<~C
z9uTnefzC8|15V;i2+c*_D<|$X{C$SvC_23UQy-GgLASPDa_Z;ILedphhSka}Jracc
zbU$HNc5nDVtBK?JP#CBp{6_tare?}&Ik)!1k86oFw#s!Vp9avba=R|?mhc888Pjt6
zg!;_Z9{gBDm;2CI;L)$x`SF%^SY_i#SWb`MHCx@8S0NAvJg8q(XuesJ9YYFo2L;79
zV1ToC=$q8{EJr*0AK)WhrzDACW{e-JZ{PN^w1>|a1Uf!gBR~4ij))aM>t1idU|m4=
z3hpRSy$iMDHiCX7O4uLN#jk?`Xw^S;mcHqyyc}Bn+0^Ki!tqo38hPB<$?XJiV$(y)
z>E1>AHCoAQ(674(^Wpv}%}j^4nriXY`&DbIaJ11hdiX&m?_>V|q1^xFiRI3XZS>M}
Vb-fZ6oxCHvrWY+Q;Eh}{{|6&RWaj_?
literal 0
HcmV?d00001
diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py
index 602c105..8800f53 100644
--- a/kimm/models/inception_v3.py
+++ b/kimm/models/inception_v3.py
@@ -299,8 +299,14 @@ def __init__(
):
kwargs = self.fix_config(kwargs)
if weights == "imagenet":
- has_aux_logits = False
- file_name = "inceptionv3_inception_v3.gluon_in1k.keras"
+ if has_aux_logits:
+ file_name = (
+ "inceptionv3_inception_v3.gluon_in1k_aux_logits.keras"
+ )
+ else:
+ file_name = (
+ "inceptionv3_inception_v3.gluon_in1k_no_aux_logits.keras"
+ )
kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
super().__init__(
has_aux_logits,
diff --git a/pyproject.toml b/pyproject.toml
index 17c6ebc..d2eed04 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,7 @@ tests = [
"black",
"pytest",
"pytest-cov",
+ "coverage",
]
examples = ["opencv-python", "matplotlib"]
diff --git a/tools/convert_inception_v3_from_timm.py b/tools/convert_inception_v3_from_timm.py
index e1acc53..5129b19 100644
--- a/tools/convert_inception_v3_from_timm.py
+++ b/tools/convert_inception_v3_from_timm.py
@@ -17,30 +17,36 @@
timm_model_names = [
"inception_v3.gluon_in1k",
+ "inception_v3.gluon_in1k",
]
keras_model_classes = [
inception_v3.InceptionV3,
+ inception_v3.InceptionV3,
]
+has_aux_logits_list = [True, False]
-for timm_model_name, keras_model_class in zip(
- timm_model_names, keras_model_classes
+for timm_model_name, keras_model_class, has_aux_logits in zip(
+ timm_model_names,
+ keras_model_classes,
+ has_aux_logits_list,
):
"""
Prepare timm model and keras model
"""
input_shape = [299, 299, 3]
torch_model = timm.create_model(
- timm_model_name, pretrained=True, aux_logits=False
+ timm_model_name, pretrained=True, aux_logits=has_aux_logits
)
torch_model = torch_model.eval()
trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
torch_model.state_dict()
)
keras_model = keras_model_class(
- has_aux_logits=False,
+ has_aux_logits=has_aux_logits,
input_shape=input_shape,
include_preprocessing=False,
classifier_activation="linear",
+ weights=None,
)
trainable_weights, non_trainable_weights = separate_keras_weights(
keras_model
@@ -129,17 +135,33 @@
np.random.seed(2023)
keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
- torch_y = torch_model(torch_data)
- keras_y = keras_model(keras_data, training=False)
- torch_y = torch_y.detach().cpu().numpy()
- keras_y = keras.ops.convert_to_numpy(keras_y)
- np.testing.assert_allclose(torch_y, keras_y, atol=1e-5)
+ if has_aux_logits:
+ torch_y = torch_model(torch_data)[0]
+ keras_y = keras_model(keras_data, training=False)[0]
+ torch_y = torch_y.detach().cpu().numpy()
+ keras_y = keras.ops.convert_to_numpy(keras_y)
+ np.testing.assert_allclose(torch_y, keras_y, atol=1e-5)
+ else:
+ torch_y = torch_model(torch_data)
+ keras_y = keras_model(keras_data, training=False)
+ torch_y = torch_y.detach().cpu().numpy()
+ keras_y = keras.ops.convert_to_numpy(keras_y)
+ np.testing.assert_allclose(torch_y, keras_y, atol=1e-5)
print(f"{keras_model_class.__name__}: output matched!")
"""
Save converted model
"""
os.makedirs("exported", exist_ok=True)
- export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
+ if has_aux_logits:
+ export_path = (
+ f"exported/{keras_model.name.lower()}_{timm_model_name}_"
+ "aux_logits.keras"
+ )
+ else:
+ export_path = (
+ f"exported/{keras_model.name.lower()}_{timm_model_name}_"
+ "no_aux_logits.keras"
+ )
keras_model.save(export_path)
print(f"Export to {export_path}")