From 2f771ac00d89e3d7452cda654129a7cddd26ed0a Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Sat, 20 Jan 2024 19:16:21 +0800 Subject: [PATCH] Fix `InceptionV3` and add Codecov (#20) * Fix `InceptionV3` * Update README * Add Codecov * Update Codecov * Add Codecov badge --- .github/workflows/actions.yml | 7 ++++ README.md | 11 +++++++ docs/banner/kimm.png | Bin 0 -> 12709 bytes kimm/models/inception_v3.py | 10 ++++-- pyproject.toml | 1 + tools/convert_inception_v3_from_timm.py | 42 ++++++++++++++++++------ 6 files changed, 59 insertions(+), 12 deletions(-) create mode 100644 docs/banner/kimm.png diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 5e20c41..c00266f 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -45,6 +45,13 @@ jobs: - name: Test with pytest run: | pytest + coverage xml -o coverage.xml + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: coverage.xml + flags: kimm,kimm-${{ matrix.backend }} format: name: Check the code format diff --git a/README.md b/README.md index ff38e11..1ae8d2b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,16 @@ + + + # Keras Image Models +
+KIMM + +[![PyPI](https://img.shields.io/pypi/v/kimm)](https://pypi.org/project/kimm/) +[![Contributions Welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/james77777778/kimm/issues) +[![codecov](https://codecov.io/gh/james77777778/kimm/graph/badge.svg?token=eEha1SR80D)](https://codecov.io/gh/james77777778/kimm) +
+ ## Description **K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner. diff --git a/docs/banner/kimm.png b/docs/banner/kimm.png new file mode 100644 index 0000000000000000000000000000000000000000..342645ea836da4fa0399ff6103d3f399b2474576 GIT binary patch literal 12709 zcmeHtc|4SD`?u~kDyf7@mKKD;Z7&R=5~8xsVC;j*mSv1>xSL2NRJIV3Y-5bEjf|LE zg)+7WexqN1w&C-1YQrS&$)>lk{5@NPF>Vhs=aeY1Jzhxq-bn{L4&aGQvmaIl;e zJR}eq>;eBRfj4POh;cC7EfnslulA>!hSHy^N@{=Vsr-+B@!s+82^YgdO+EFs)jd_! z-PJsl&bg_&DQT#=xhuK<>7}itq3NNfsqOWrr>e%;e*|q#`EN0o+yNIq)2<(4Qz^f1wzTrkj7^no4*UG6-5ZX$#IenY_1Lj~>!_lF0B@{FXqWh74zolr!mAN=1YGV}}mZxaF7)A4f) z@CM6;D|y1b+>m~ua{A_`78cS<(yITL)ocmSQ~3|s{|{>cT;QL@@G$q=v3O_;cf0u? zXak1n{6}1|4Xf#j{qkE+)3phrvq;Oaa(u*;H$xa zQ_4}4P<(vSe5Mx+ZNi^U<*xKAU5sd-zxq8h?Ssu?+~mCjo-QhT?uz|+`NT;#krQ!U z;Z29lgH``3m43YE?w(zLR3ACWUw!Pt>mL}u>hYaWqw|M%-ZIZHMfbA@7CLp!d!FWn zbT32~PgskME`dWx`FZ`^bKHf6HKWD$*r(foP;y=`#_hvSUQoF#xMMGHkyhguV1DXb*?4xSuWhXmNOlqV44E)!nfV;pElEkq@l3 z`3DEd#-Dr&V}fQ^av4EmRUrv&Z5wpP1{#%zpW!CNp1l&)8JP=u-bsmL5XnnX++pH_ z?tb>My$Fvwlg>q3nDtLN(%Wn7qz`QqCX-Lg+5=uoL*xs)T`5X6em&qCzon5HzZCHB z2S(HC2$V>*ix@k1Y0c7HVpzE`G|@PVT3+K<*;MUUKN1}DvtN=PIP;ZWbX;M{D~BD+ zco%h_EwIdwWF&B*JfpT^q`W4^?sNi z)ZiSYG|7!mK9&5gnz3eq%KJ#WH@s3PB}le3sw_aVoYvaI1(Akx&vo?Y?v^jz5#+ji z2Q@AZ$Iu`WIw@YT(-$IBHr8i`hF>}jGS;WU;V`>}hW@dcM3+tjGTOb=*SKB1=cS6Y z$NHhY71EG0cMQ8emZjgju}X&ao;dg+Xrk=fVO+Hn;$c3M5MJt`_fuNVR0l;{ND>&7 z$$UAFc*qJq9`T)kDOy})iJ$U9eR@=4V|`(&a&ftW(NVp9|6Q?qRwc^lup>cZyLcg* z`>n*FtJ z6OCa~(mmU};GsUQ0%ftbAm<1f%HoK2e$u#lMIf9r78^D;&|uu*y8WvQm&*H;X)F9g zHJ0Ogx32DcGvu8@J@s?uYmULY-D<|3s^V6JNw62ia~z*r(jvs^2?0v_)Fx zC>0t7)1R-XTKJR|nyYltLo+@mT_<4s;t~TgAx|!_LRJfe8X`GmRQ=r;O?lp@U|iF^ z@mCJ5CF4?@#`ovLP)H+TU)cM@J1KY>rDs?DXGM-3*B;TJ+t{ks^=?N4@;JN>Zf zkHe+}xWAkxxuaJl`3GN|Cs@fKTNEf%L`lWuTh2QwhaDhw5{Xq6o3b<|OkUQs_^A>- zxOh~@D|cx2o2F=sLWw{ONLqwI?HYe)(H}c3G{y4ZC{Lknn6fPl5u( z_&xQFP$R5N^MGJb`#5VIvsMmXqP7af=1X-i7X`SBRg~~cSU`%?E!gql|u0QRAJX$l&X@!~h z*tIBIwswSG*{l;tdE)5!;5xy|TzHV3&o3Iv*J75QwnMSlY1}0>r*Wpcy1to;wwAEM zuT@|DJ?zGvOG522zgp9$G$^W z=1h#T&l6;KV$v1l`7>1PNiyr-q;vR%U3?D^(!FkU_}k!7 z{TucYl+%#KDDa2RExSiz>qq1fV}ssTX0(536bwVv@i9% z@WyDcjMX7f<0a+hvy^UKmCCE{K1b1D-z=ntyeHy>?^j`ShI;F;COhN11C}u&*4H@v z8R+g>^!uQ85!RrV^-4|tIRRVU6w~SfsgJ|uBVi%xcqS>Gg){c|%lGyo{JkA8VAUQw zzlLhg&t#RNVvVS}y9+k=F-T@QlsMDeT7Bn*0Ar;)s93{rJQLc1X&1<{`Q| z1F?&u?;Git<~iOlmd(1JSC$r+98}!Lv2aZ8J^f)(9sQyA_Ls)DA}!fF0K@~ZF=ApX z_A^ONLGj`QQ5pM_7h6b+4%+9}W)r42Dv~B9Q~`Yy?M4mDQ)1TV)2FOO@>!KV)v1;f zk~H`s(7nIB(8efS{X8RjKDf#X(bK&7J#0bg#Db`C1mnXGO zDD-v8G$qtpHw`O%&r0;L6y}V#fEUlSm6<*#&xzEN9et$YH^W=oX-G|nZNx}o(~sAC zNA5gK%usxUb!}oQLBl@ZYMm?2mbHnfYf6y1U91y899s4!q#YY@4tl)J^cGjSRil8N zmCv!M>cXBL{u(lHJ-|*qG*@CX;2lkLP4CClbfk0ZMP1AGsq{q=C+nbkdjc__LFAELoUK(O13jSsWTz;6N*Br*&vd=6C{ zX&M8!4wVK=us0Tmnto)1R?}VHS2rzwRW>T!0hk9a1|;uXbeQuU{a|>4*QL{Hg9R$L zo$lG}G*{O>_GFPkhWzN&D7g2wy4G6aaa`22{2&=0aAX6I=>eu_{wJq# zEqH0ZH><-j;HXJc5IOd-w0~8+p%diUZa3mUt)Cp2~8~X7Y1snl%ac#yd zdPXsWEDtUPe;#PVift|A0S3-EWP#V`j<2D;A1OvH4pn!gUxpGw_X}QIow5g^=)en3 zOg(9uL>LKTmPQ?MQLtR+NO=lzsEvA&=Ba=wYGt_hTa5fDT-26-A;8YETFGvfPV2ud zwI+?($)j1-ei(Rg(MHCZj}8g?!AzlGDNgSkh4+#x$u?Tx{A(sY*|QJSGt_@)NxtEQ z!`sOE_wVg;9)Xrc7M+9~+E(|d`3Nj5YZaz|FuGSbv3OhRFv~f>V|AtfYU7ii$FA$1 zfLzX}FQ55V?B3r6iQ)~0{ocLmaI_u@<^pbIP zPxT7B670e=!Kdw#tl@}Mi@w7E>=GpGaI291xW%ILM$vz7>q5@;-_CUUIt(cb+OhW- zkn7>l9pxZOAy&POcD%1W$kh)PlyZ3M`P)4HnQrNndGPSk6JZlfb~HTMqOWj z^5>QfY_HA)Xwd7Ai+UETD;&0E9$TmNj~2`>fb>9XUDKHpHThk-Pnv^^V6~d*fw!wh zJ@y656iY+Gx)5UMa*ubPSGxk0U9x*8+{Y&$#`oMf_1>zi`ZK)T9l zaH;-fU*Z7}KaUw)6tKI}P;F_R=&9*re(efgRDyxOVo1Po*)vPcne|cDBN<9OI0ewd zUgr<><|4JCnC+SuTdI9mz9x<^;<|0t>0~AQ>}$wAqVQvn@(^1vNc6G)#$GOnCb_@JEX!4!ypXyW8%vIUOFIvm9>Z_f;{0)&`;z4u|A=<%YDE{r?9px0` z5Q>3Iu||QFtHV3o&Muu{QGJ7z7v@>xNao98SxC$Sn=!dOpX3kKK%P8AUXZcJ%ibey zj79}!{!z~oZo!E1cx(yJq_6kRcfY7~LJT_#Z70$T+Cld2doWBFp|7`g&#|7B#e_9q zjdriPU1^{j0%Q?2tuDW{egl#GTjTATwg7mbNZJP(iz77w=<8@L#wY&@H@3f-(1e3> zxE<Gd*MuFFF)c2I-WjSQT?%3Re>jC^~As!51syZ6TtRr;*2HmKDV?M60n+-;M6?_-$l zW@Q7#CY{UvuRsW+)HlBPASS|G2>U5&8X65sYZci<{ZR)(bc2<*h%sI=<}8IcsuTOy zO$7k*Cn4+Xf^EyqtM@C*$_u&csrgQwH3HlPS9+oAT9)9{jGd(rC&&Uyn9L^j#QeO% ztf*JkvM7c2RS-L>#auwd5Lg76J2k;M1?l|*2tmes?4bBbRpEkQTk1VT((S0U@dlocP(vNPmW7E(_%?K0UI}Bp)I*kY7Y8qI|^`_r0i5JP+yPIK{B?9RVzjm|X zmclm({m3(Ay&s6P&?qz%&H9>+bzSkVASu{Rkn#AWDK1q->7<2E&X>`f#;QMT;nXTO z(@XgVBi~%2U~aD8!)O>Qv37w}B!N0XxTLUac(C zmBQ>41bpm}Bf-7PDV@=?e?&CDBV(WKmFtAbV`+b`j0FHiKt;q+68f+Um@qU9B(>68 zR2qGY95U@B7YOPqU-M$v9|LusndCZ1%9U9*7F0a+|7 z0F;NLGg;o%eu+wR`?WRSXJ@;vh^yYH`G`ln9zo>?W?$Q$i#PBOC@)k7`>>a%QHkP- zgRKB~=jaYJrbmj1izBTq(M;Ikc&%8b)C^W_qIdo(7vf`*`?7Fd-LfX+A=LxYljo( zdr*}E7Q!)H*1S0{nAUh@gXihB_~`nPK_~sKwPd&=^2Kvih)zSklSq%hW z9lGQM96t3cG6{(xPTaeH}D%9V${uI(y@ek2AuEB}A)G zB$}jNsP7`2*u<;q4vHi}FV1OYW*nubv!TC=`*bvF60pRYMnRr>>6l;!%hukd)ugO7 z;)Ot+(pw^(v^!Bdy=Ma51DgMz;m>_{s+C5ot4jYivmOwGFe>5%keiOMu+`(? zE6l~+)?2PovuCiZgtW~lx34G@uV>QVutpxATUs0)KI+hjD4#clg9m}a+r((DD^S$M z#hwvSEPxECqYT%zaub3}$NI`zRH;!xPIUmBv|Kmp)i_)c|Y8F0w z^HCpLisG_6;K=o*SX8jcF&a=_G6p&J;snlEsjr7(;^0L=uyxr{o1!r8S^IR;1gloF zqxD6%4sN2Y83S+czgoE#ao1@jcOpq!R5z~}x_4tOEScsU;)~Ka`v46d1k#d;TFBB! zYsdg~C?<9g`bMyM@#u#)30xAO&bW@h0cC%Um~NP^OSBXYgKqy>UxoMW%BxS`y}6?C zQ;-;M>uDecq~|f%6!O!OBgVSV<|1*sr9!O=zzJ=yBz-3^B9Dv9B88^@Fc^^b-2Q^N zca~9u7bN*5pVMralS-*w8fDiE*#^~VxQ)bE*Elms(+dZJ$Pz$lGJ%`YncOyN!4=%h zxP7Qyl*}T4k+e9ZNIW`>s3M70I369=DZsKJj}NhCMI~s>SD=Ko__oIBiM$y*ybG); z0dUF)4TA4qfeI;9dvTkvgB|h$fFg}t9<0p|)cw-yi}O$MonxAbcinh1^pJCM6Z18l zuJOl_!9}5oqxL7Csa@VPy0QM1wHBMxJ%x&HKb{&WuG$90$?wfb@rs>(RTG{^igbK% z=Q6;t3qj{C#tp&lGPTD6?@KJt>n}Z}VX>SP7+2xHw>AHIDjyU*(RVwnV*I*kLB9GH z5CJ08^XJC`#}Ny>%`RlHtwg}eapcSw)aA&ttf4X(WM#cX-#H+JvgIyrmliD^)%#D)RRd~AnXwWyP6Xg zRJVA>HH@vCy+i%y*v2JYIc~kq<3M!zGO3ph8H_+k$>g?ogJs_v5$>-Z9N%>(Sb{Zy zU!sTfaCR%ATCUnR&*x5!o0YQ!wO{)4liqzF*g;9D%Gmu06LfFz5-h0_3z#u!YMRZWd{N@pb=Lu7_Rlw^!{lRlEm=LO%lGH#I~a~tcumtT(sJC!IvM7>A0If#v6Ni4U{nim2V zxYKWes?M)R_FnVP=hKcY(*Mc==nS<5j?pHQ^b)~gcm5#R!WuPtNbj6-pPXhTx`6TwLPaSXx>Hp$ zzaU`%M1%cgts9^Fq85T-EcZXnvk4T@KLC}c+lM&WN z2P)L$!Qr4Mj=gw9GrlV$txJ3Qt8$i=>F=tVF~>p9fW}x_Eq(wCmQBWw8yeuPMR5;9{cO3$Ql+rA zzT2Fq$3w4GfpboC=JRKkUhr@Fdr_sU?R#oSPeVdcTeke1u92hM7nCK8<^e;&KE8p9 z31CFR(E9<{(V}Ofrk(s(PlD&KWcZU1OId=zD?~sB7B3@_tGnJXHNJG~F7f(ElR(kZ zVAOYeCMuuaS5SW8Ul~i3&Vd~%LrugdBO9l)i3 z{JS#Y%<5Q0Vtnfrqfz9n?Wcw8o*1o^SbAq$%Wf3?xA%#GSPA;;zub?z1I`W&no6~& z=EdmJ>q4g+xYpg9k<5&76vX0_jQQwa@+>*xOv4;1g63gPnw*>77}a5#NOe^wX@}(w zR#&te5hf}>2ww#?qefQg$e<1G2vA*C0M*1VC*;hpx5Cp76RDC5-*f;@*B=*o|Mp^c z&5|U*fVF#kOu{5A4*73z^OhDEj)GgMPZ2V>&9|McO&$_dGroU6UjgTkBK%mPw%ev% z#D3qYlQkIEino1nx|0v#D&$wy-N>q+k9qX2Z`6OiA=^uEx_hg5xM98@d$h_9rv;3%>qzc!iWK97%S{YDIi89 zXh(VT2r}6*Kt8WgSrXkBCmLm%+S;rHK68{sQn9jiI*5Cv!ui;Kzskp4cr?falHGgS zp~l`sGAom8`Fv_EpEKZxKE4^!)XH5Ee^dEXg<1ezG9rDK&KxJJPmmPC&&;(TG^s9aO~ z72?O|%l!D_g=YcS@wwS)(n}zz^8tmIhcM_(F$bCWDj8p9w4o{H{G_tb5P(JM#l0Ct z34yPw!q-VpAd}taO3<@)FIQ>S{hfD`(ThRMEas(q(4C5xoVd%%4g{Qm&Qz_to|lDi z_J#7K6tI>`)XFD#knXn!L;s?BM+;q78eg%li-on@XI6-Ju&Ok~wOE(E1~-61LN2_s z0EKAxzg+8F%6(~~l->;4=+CznH46G}AA6SnWYOGS!PPeH*6w|K?z1B}dCWAnl_1ljL|1SA5SA^EH3TQXi9 z_Wg<|hWemSlO8|^TR>1b=%3A(7dL7qtEh+Vt>P{&oEyNUxiCmtM27ERd6`A%l#h8< zrn5stl1;1qOl2ti!9VH~0euM^dFI_S&OgE&nfX4ci5i7YUK2Y3q*HZf<7JN*0!ky% zL-2^s94{@eh$o>i0C(}0+o`TYX4W+UE{j4ygJrd zEom;a4EsJP$ZrQO*j^O-@jSIQ^U*F9usP8CNgTtK_cSfFiO(LuO$7VwGEW(;wum?O zQpnMg&{QR_bKLi_y4e%u(8b02rKoNU?TG4i^h^!g}Sdp&#ux(A#JvW+kcjJg8q1$JnNP|FK9S zZk*EqFS9G#6P31HC9LI840iT=Y9jk5;Q2ZT!p9qmnE_9h3PS~I!R2Om69n*uZJC1*ujBy(7i(@RebuKDLTYSo_}Y{jkc@8>s^eO^Yujn6~}>{k&atfmWOrPo#d+7hU4}-eBsd3W@i2J zFmE^_r)*%~BGz^#^Ay>|XL$6DL4suai*~h@u_tv~Jw#yR%XX9rlSpM;kH@#qv5HN? zI-`KxCUVD0KN}tmY@4Z;CR#)+zW1t3KKq)he;ORl84r>YjCa+{G+{RaduUCp;gLwrlO`B4>MP&I{>xRf{gQuO08CX2;0EiHHM!rGzp;9*i)wKFuiU$|83KXb zs6tCC(u+M6F$T+qqb*$?!1q#<^JVpbqFJ&4VbqxJ0{;MuhxOfdxG6Dr?TbFJFN38N z5{D5bYd_L~y#ggq8YZmEX7R^|cQBJC8CH z4&ZKT@CP)&>op2SVk))s0fvkbq=mw=hxd69FU7ht!DnwU78WN>>ZSjkUrCZS!!n+ zGm}Opz6y6xm#-4l&-`}ZmD*l{jm0WvqMe)9;Pzr@QT@3`uW-mi{_NfdK9ggQw~7_{ zJsGq00liZzF=KQ%I0G5#5L**jC!UtzloIb6Umy7HO;mZuDb?!_wCJGdkPnq;|0AXs z2*Z@qWz^hTm9HLWmRY{w&Yj_whXv4`2At<#?}ZXuMc;J2-s+i`?a8>B{~L*owB|00Rj2S8tD=3-w$ZNY<~>EMe5jJT8vlA`#K(KRU{UXM+5H0_b`F%@`c%7cysOS!nx>s z&C6fi0-bPO?^OpL+~5{f)s%ftHCo&F}ok==o8*6A9^2+{h)kZ|TCbe>ID< zvRwsYS$Ar>gV{exlt#Cp-aK7Z9yQ{AkbW^T#LYQE+Sklncs4j8nDyHa;xn(KwlhMl zRYhU@U7pv$+8>S^Dyna`In{vL{lkbG>C`n71{<s-AdAv0xkQ+K zOZBZOUW2mVB7t{6joNc5`S zmtAmEM?fzBpx;#4QMuN?hpSnieSbAC^nF>ph`q0f5lPl}8HSPrvwJ&mnom>*E$#r@ z=D#iP&iyymNd#Zj{v;<)O;2f#wSptv8UL=sxtr133 zgqikJ-)ChwujcPWG<(`cY#{=0zk>bOtX^z0H=)_MMSurQ<=8Gjc-cM5N`O@_NHVv7eK{`dvVV@ptJV`xnlGpmaS(NwWIIhzmnq>7{PsdE}$}k%LdWW zC|U$6cbg%>%hy~OhwbvDyZcnFzZSd0b1+2=Rj^2)eng8BGne}6>4|(>#hL9wFIn)P z|28Hf(h4m+c>Mv4aN(WuOaa76GP)Oun>f-KZHzyJd$?4$Kf)O!Vx^qes^?9JbH~@+ z?9rz;DIq7x8?(H%gPCd5K@z;9O~dZqv|M-a)q@0$0t413=6uBB&^%CMLaCNf(^>Gq zS09w{oYx7iGC}|!Wgf-h$J4B@N{|68WiBNuD-2e#+!*Wr1+{{Iow!|>{^hax{jojZ zTA_i5z$vc%NlDx;B(2My;3vwero5)ypmBkCDpjS>PW2*?-r5Q&OMP|giJGEvhx1GL zm-DHbzW1lZf=dG)OC*lyi8Y|OdIlSMn@6lvTlz>RB+LTG9W$E8sRazyOz6)X6tGnr z>z|1#IuPpjvG;f)_9K6jWqVb7SX2fKItLuRTLUAH2fZa=;3p3R=YG_D9=|bv2qYdY zvK{=leL`!Ae9Zlt{a}}1=W)Z0`T79Z(o5&4shjGMNTK)V9ifvtwE49kN`IIM2Q<~C z9uTnefzC8|15V;i2+c*_D<|$X{C$SvC_23UQy-GgLASPDa_Z;ILedphhSka}Jracc zbU$HNc5nDVtBK?JP#CBp{6_tare?}&Ik)!1k86oFw#s!Vp9avba=R|?mhc888Pjt6 zg!;_Z9{gBDm;2CI;L)$x`SF%^SY_i#SWb`MHCx@8S0NAvJg8q(XuesJ9YYFo2L;79 zV1ToC=$q8{EJr*0AK)WhrzDACW{e-JZ{PN^w1>|a1Uf!gBR~4ij))aM>t1idU|m4= z3hpRSy$iMDHiCX7O4uLN#jk?`Xw^S;mcHqyyc}Bn+0^Ki!tqo38hPB<$?XJiV$(y) z>E1>AHCoAQ(674(^Wpv}%}j^4nriXY`&DbIaJ11hdiX&m?_>V|q1^xFiRI3XZS>M} Vb-fZ6oxCHvrWY+Q;Eh}{{|6&RWaj_? literal 0 HcmV?d00001 diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py index 602c105..8800f53 100644 --- a/kimm/models/inception_v3.py +++ b/kimm/models/inception_v3.py @@ -299,8 +299,14 @@ def __init__( ): kwargs = self.fix_config(kwargs) if weights == "imagenet": - has_aux_logits = False - file_name = "inceptionv3_inception_v3.gluon_in1k.keras" + if has_aux_logits: + file_name = ( + "inceptionv3_inception_v3.gluon_in1k_aux_logits.keras" + ) + else: + file_name = ( + "inceptionv3_inception_v3.gluon_in1k_no_aux_logits.keras" + ) kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( has_aux_logits, diff --git a/pyproject.toml b/pyproject.toml index 17c6ebc..d2eed04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ tests = [ "black", "pytest", "pytest-cov", + "coverage", ] examples = ["opencv-python", "matplotlib"] diff --git a/tools/convert_inception_v3_from_timm.py b/tools/convert_inception_v3_from_timm.py index e1acc53..5129b19 100644 --- a/tools/convert_inception_v3_from_timm.py +++ b/tools/convert_inception_v3_from_timm.py @@ -17,30 +17,36 @@ timm_model_names = [ "inception_v3.gluon_in1k", + "inception_v3.gluon_in1k", ] keras_model_classes = [ inception_v3.InceptionV3, + inception_v3.InceptionV3, ] +has_aux_logits_list = [True, False] -for timm_model_name, keras_model_class in zip( - timm_model_names, keras_model_classes +for timm_model_name, keras_model_class, has_aux_logits in zip( + timm_model_names, + keras_model_classes, + has_aux_logits_list, ): """ Prepare timm model and keras model """ input_shape = [299, 299, 3] torch_model = timm.create_model( - timm_model_name, pretrained=True, aux_logits=False + timm_model_name, pretrained=True, aux_logits=has_aux_logits ) torch_model = torch_model.eval() trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( torch_model.state_dict() ) keras_model = keras_model_class( - has_aux_logits=False, + has_aux_logits=has_aux_logits, input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model @@ -129,17 +135,33 @@ np.random.seed(2023) keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) - torch_y = torch_model(torch_data) - keras_y = keras_model(keras_data, training=False) - torch_y = torch_y.detach().cpu().numpy() - keras_y = keras.ops.convert_to_numpy(keras_y) - np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) + if has_aux_logits: + torch_y = torch_model(torch_data)[0] + keras_y = keras_model(keras_data, training=False)[0] + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) + else: + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) print(f"{keras_model_class.__name__}: output matched!") """ Save converted model """ os.makedirs("exported", exist_ok=True) - export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + if has_aux_logits: + export_path = ( + f"exported/{keras_model.name.lower()}_{timm_model_name}_" + "aux_logits.keras" + ) + else: + export_path = ( + f"exported/{keras_model.name.lower()}_{timm_model_name}_" + "no_aux_logits.keras" + ) keras_model.save(export_path) print(f"Export to {export_path}")