From 6f88d0e23d73f77e86efddefe310932df80d9552 Mon Sep 17 00:00:00 2001 From: "Jayesh K. Gupta" Date: Mon, 4 Jul 2022 18:02:57 -0700 Subject: [PATCH] Add Lux.jl integration with Jax (#27) * basic support for Lux + Jax * update version * add mixed example as well * fix some typos * make sure only positive seeds * rename to allow more wrappers for other jax frameworks --- Project.toml | 3 +- README.md | 1 + examples/simplelux/Manifest.toml | 519 ++++++++++++++++++++ examples/simplelux/Project.toml | 5 + examples/simplelux/train_ml_explicit.jl | 44 ++ examples/simplelux/train_ml_mix_explicit.jl | 54 ++ src/jax.jl | 9 + src/lux.jl | 30 ++ 8 files changed, 664 insertions(+), 1 deletion(-) create mode 100644 examples/simplelux/Manifest.toml create mode 100644 examples/simplelux/Project.toml create mode 100644 examples/simplelux/train_ml_explicit.jl create mode 100644 examples/simplelux/train_ml_mix_explicit.jl create mode 100644 src/lux.jl diff --git a/Project.toml b/Project.toml index c7f2ecd..fb5d4cc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PyCallChainRules" uuid = "b12ccfe2-7326-416f-9f4f-cd3183bd9fe8" authors = ["rejuvyesh and contributors"] -version = "0.3.2" +version = "0.4.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -28,6 +28,7 @@ julia = "1.6" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/README.md b/README.md index 0c95e34..afd1bf5 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,7 @@ loss(p, x, y) = sum(jlwrap(p, x) .- y) grad, = Zygote.gradient(p->loss(p, input, target), params_jl) ``` +When mixing `jax` and `julia`, it's recommended to disable `jax`'s preallocation by setting the environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false`.
## Current Limitations diff --git a/examples/simplelux/Manifest.toml b/examples/simplelux/Manifest.toml new file mode 100644 index 0000000..c1a3eca --- /dev/null +++ b/examples/simplelux/Manifest.toml @@ -0,0 +1,519 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.7.3" +manifest_format = "2.0" + +[[deps.AbstractFFTs]] +deps = ["ChainRulesCore", "LinearAlgebra"] +git-tree-sha1 = "69f7020bd72f069c219b5e8c236c1fa90d2cb409" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.2.1" + +[[deps.Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "af92965fb30777147966f58acb05da51c5616b5f" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.3.3" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[deps.ArrayInterface]] +deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "Static"] +git-tree-sha1 = "6ccb71b40b04ad69152f1f83d5925de13911417e" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "6.0.19" + +[[deps.ArrayInterfaceCore]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "7d255eb1d2e409335835dc8624c35d97453011eb" +uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" +version = "0.1.14" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.2.0" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.CEnum]] +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.4.2" + +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", 
"SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "e4e5ece72fa2f108fb20c3c5538a5fa9ef3d668a" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "3.11.0" + +[[deps.ChainRules]] +deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"] +git-tree-sha1 = "b06ed86d99c982cbe9047a45a93ac62d9605a361" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.36.2" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "2dd813e5f2f7eec2d1268c57cf2373d3ee91fcea" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.15.1" + +[[deps.ChangesOfVariables]] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.3" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["Dates", "LinearAlgebra", "UUIDs"] +git-tree-sha1 = "924cdca592bc16f14d2f7006754a621735280b74" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.1.0" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + +[[deps.ComponentArrays]] +deps = ["ArrayInterface", "ChainRulesCore", "LinearAlgebra", "Requires"] +git-tree-sha1 = "7573fc9e81ca1031a1ef80d2dcd1765763068352" +uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +version = "0.12.2" + +[[deps.Conda]] +deps = ["Downloads", "JSON", "VersionParsing"] +git-tree-sha1 = "6e47d11ea2776bc5627421d59cdcc1296c058071" +uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" +version = "1.7.0" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "59d00b3139a9de4eb961057eabb65ac6522be954" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.4.0" + +[[deps.DLPack]] +deps = 
["Requires"] +git-tree-sha1 = "915b0cb087ac4fd84fd0bf8f8b123c4d0b12d552" +uuid = "53c2dc0f-f7d5-43fd-8906-6c0220547083" +version = "0.1.1" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "28d605d9a0ac17118fe2c5e9ce0fbb76c3ceb120" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.11.0" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.6" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[deps.ExprTools]] +git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.8" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.13.2" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "2f18915445b248731ec5db4e4a17e451020bf21e" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.30" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "a2657dd0f3e8a61dbe70fc7c122038bd33790af5" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version 
= "0.3.0" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "73a4c9447419ce058df716925893e452ba5528ad" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "8.4.0" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "4078d3557ab15dd9fe6a0cf6f65e3d4937e98427" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.0" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "47f63159f7cb5d0e5e0cfd2f20454adea429bec9" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.16.1" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "af14a478780ca78d5eb9908b263023096c2b9d64" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.6" + +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "b3364212fb5d870f724876ffcd34dd8ec6d98918" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.7" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[deps.JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.4.1" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.3" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] 
+git-tree-sha1 = "e7e9184b0bf0158ac4e4aa9daf00041b5909bf1a" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "4.14.0" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.16+0" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "09e4b894ce6a976c354a69041a04748180d43637" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.15" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.Lux]] +deps = ["Adapt", "CUDA", "ChainRulesCore", "ComponentArrays", "FillArrays", "Functors", "LinearAlgebra", "Markdown", "NNlib", "NNlibCUDA", "Optimisers", "Random", "Requires", "Setfield", "SparseArrays", "Statistics", "Zygote"] +git-tree-sha1 = "32c357dcf390fb2a447efc2af5b2034f7b724fe7" +uuid = "b2108857-7c20-44ae-9111-449ecde12c47" +version = "0.4.7" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" +uuid = 
"1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.9" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[deps.NNlib]] +deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "1a80840bcdb73de345230328d49767ab115be6f2" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.8.8" + +[[deps.NNlibCUDA]] +deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "e161b835c6aa9e2339c1e72c3d4e39891eac7a4f" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.2.3" + +[[deps.NaNMath]] +git-tree-sha1 = "737a5957f387b17e74d4ad2f440eb330b39a62c5" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.0" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "afb2b39a354025a6db6decd68f2ef5353e8ff1ae" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.2.7" + +[[deps.Parsers]] +deps = ["Dates"] +git-tree-sha1 = "0044b23da09b5608b4ecacb4e5e6c6332f833a7e" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.3.2" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", 
"Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.3.0" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.PyCall]] +deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"] +git-tree-sha1 = "1fc929f47d7c151c839c5fc1375929766fb8edcc" +uuid = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" +version = "1.93.1" + +[[deps.PyCallChainRules]] +deps = ["Adapt", "ChainRulesCore", "DLPack", "FillArrays", "Functors", "PyCall", "Random", "Requires"] +path = "../.." +uuid = "b12ccfe2-7326-416f-9f4f-cd3183bd9fe8" +version = "0.3.2" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "afeacaecf4ed1649555a19cb2cad3c141bbc9474" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.5.0" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.SHA]] +uuid = 
"ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "Requires"] +git-tree-sha1 = "77172cadd2fdfa0c84c87e3a01215a4ca7723310" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.0.0" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "a9e798cae4867e3a41cae2dd9eb60c047f1212db" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.1.6" + +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "46638763d3a25ad7818a15d441e0c3446a10742d" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.7.5" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "9f8a5dc5944dc7fbbe6eb4180660935653b0a9d9" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.5.0" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "66fe9eb253f910fe8cf161953880cfdaef01cdf0" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.0.1" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "464d64b2510a25e6efe410e7edab14fffdc333df" +uuid = 
"a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.20" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.VersionParsing]] +git-tree-sha1 = "58d6e80b4ee071f5efd07fda82cb9fbe17200868" +uuid = "81def892-9a0e-5fdd-b105-ffc91e053289" +version = "1.3.0" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "3cfdb31b517eec4173584fba2b1aa65daad46e09" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.41" + +[[deps.ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.2" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/examples/simplelux/Project.toml b/examples/simplelux/Project.toml new file mode 100644 index 0000000..bb9a6fe --- /dev/null +++ b/examples/simplelux/Project.toml @@ -0,0 +1,5 @@ +[deps] +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +PyCallChainRules = "b12ccfe2-7326-416f-9f4f-cd3183bd9fe8" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/simplelux/train_ml_explicit.jl b/examples/simplelux/train_ml_explicit.jl new file mode 100644 index 0000000..f659980 --- /dev/null +++ 
b/examples/simplelux/train_ml_explicit.jl @@ -0,0 +1,44 @@ +using Lux +using Optimisers +using Random +using PyCallChainRules.Jax: LuxStaxWrapper, jax, stax +using Zygote + +input_dim = 4 +output_dim = 2 +hiddendim = 16 +batchsize = 6 + +jax_init_fun, jax_apply_fun = stax.serial(stax.Dense(hiddendim), stax.Relu, + stax.Dense(hiddendim), stax.Relu, + stax.Dense(output_dim)) + +jlmodel = LuxStaxWrapper(jax_init_fun, jax.jit(jax_apply_fun); input_shape=(-1, input_dim)) + +rng = Random.default_rng() + +ps, st = Lux.setup(rng, jlmodel) + + +input = randn(Float32, input_dim, batchsize) |> Lux.gpu +target = randn(Float32, output_dim, batchsize) |> Lux.gpu + +loss(model, x, y, ps, st) = sum(abs2, Lux.apply(model, x, ps, st)[1] .- y) # summed squared error between the model output and y + +@info "before" loss(jlmodel, input, target, ps, st) + +function train(model, ps; nsteps=100) # run nsteps ADAM(0.01) updates on ps using the globals input, target and st; returns the updated ps + opt = Optimisers.ADAM(0.01) + state = Optimisers.setup(opt, ps) + for i in 1:nsteps + gs, _ = gradient(ps, input, target) do p, x, y # only gs, the gradient w.r.t. the parameters, is used + loss(model, x, y, p, st) + end + state, ps = Optimisers.update(state, ps, gs) + end + return ps +end + +newps = train(jlmodel, ps) + +@info "after" loss(jlmodel, input, target, newps, st) \ No newline at end of file diff --git a/examples/simplelux/train_ml_mix_explicit.jl b/examples/simplelux/train_ml_mix_explicit.jl new file mode 100644 index 0000000..f8fce73 --- /dev/null +++ b/examples/simplelux/train_ml_mix_explicit.jl @@ -0,0 +1,54 @@ +using Lux +using Optimisers +using Random +using Zygote +using PyCallChainRules.Jax: LuxStaxWrapper, jax, stax + +# Note: when mixing jax and julia layers, it is recommended to set +# XLA_PYTHON_CLIENT_PREALLOCATE=false + +input_dim = 4 +output_dim = 2 +hiddendim = 16 +batchsize = 6 + +rng = Random.default_rng() + +input = randn(rng, Float32, input_dim, batchsize) |> Lux.gpu +target = randn(rng, Float32, output_dim, batchsize) |> Lux.gpu + + +jax_init_fun, jax_apply_fun = stax.serial(stax.Dense(hiddendim), stax.Relu, + stax.Dense(hiddendim), stax.Relu, + stax.Dense(output_dim),
stax.Relu) + + +# Mix of Lux layers and Jax stax layers +# Note: Lux's optimizations don't play well here; they are turned off via disable_optimizations=true +jlmodel = Chain(Dense(input_dim, input_dim, Lux.relu), + LuxStaxWrapper(jax_init_fun, jax.jit(jax_apply_fun); input_shape=(batchsize, input_dim)), + Dense(output_dim, output_dim); disable_optimizations=true) + +ps, st = Lux.setup(rng, jlmodel) .|> Lux.gpu + + + +loss(model, x, y, ps, st) = sum(abs2, Lux.apply(model, x, ps, st)[1] .- y) # summed squared error between the model output and y + +@info "before" loss(jlmodel, input, target, ps, st) + +function train(model, ps; nsteps=100) # run nsteps ADAM(0.01) updates on ps using the globals input, target and st; returns the updated ps + opt = Optimisers.ADAM(0.01) + state = Optimisers.setup(opt, ps) + for i in 1:nsteps + gs, _ = gradient(ps, input, target) do p, x, y # only gs, the gradient w.r.t. the parameters, is used + loss(model, x, y, p, st) + end + state, ps = Optimisers.update(state, ps, gs) + end + return ps +end + +newps = train(jlmodel, ps) + +@info "after" loss(jlmodel, input, target, newps, st) \ No newline at end of file diff --git a/src/jax.jl b/src/jax.jl index b2e5dc9..dd70ed1 100644 --- a/src/jax.jl +++ b/src/jax.jl @@ -4,6 +4,7 @@ using PyCall using ChainRulesCore using DLPack using Adapt +using Requires using ..PyCallChainRules: PyAdaptor, fmap @@ -18,6 +19,11 @@ const ispysetup = Ref{Bool}(false) pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject +### XXX: type piracy (these methods extend DLPack functions on the Base type Tuple{}); +### they let empty parameter tuples pass through unchanged +DLPack.wrap(o::Tuple{}, to_dlpack) = o +DLPack.share(o::Tuple{}, ::Type{PyObject}, from_dlpack) = o + struct JaxFunctionWrapper jaxfn::PyObject @@ -59,6 +65,9 @@ function __init__() ispysetup[] = false #rethrow(err) end + @require Lux = "b2108857-7c20-44ae-9111-449ecde12c47" begin + include("lux.jl") + end end end \ No newline at end of file diff --git a/src/lux.jl b/src/lux.jl new file mode 100644 index 0000000..2e8cda2 --- /dev/null +++ b/src/lux.jl @@ -0,0 +1,30 @@ +import .Lux + +using Random +using DLPack +using PyCall +using Functors + +using PyCallChainRules.Jax: JaxFunctionWrapper, jax, pyto_dlpack + +struct
LuxStaxWrapper{N} <: Lux.AbstractExplicitLayer
+    initfn::PyObject               # Python init function, called as initfn(key, input_shape)
+    applyfn::JaxFunctionWrapper    # wrapped Python apply function, called as applyfn(ps, x)
+    input_shape::NTuple{N, Int}    # shape tuple forwarded to initfn
+end
+# Keyword constructor: wraps the Python `apply` callable in a JaxFunctionWrapper.
+function LuxStaxWrapper(init::PyObject, apply::PyObject; input_shape::NTuple{N,Int}) where {N}
+    apply_jl = JaxFunctionWrapper(apply)
+    return LuxStaxWrapper{N}(init, apply_jl, input_shape)
+end
+# Lux parameter hook: seeds a jax PRNGKey from `rng`, calls `initfn`, maps DLPack.wrap over the params.
+function Lux.initialparameters(rng::AbstractRNG, l::LuxStaxWrapper)
+    # Draw the seed from 0:typemax(Int32): abs(typemin(Int32)) overflows and stays negative.
+    seed = rand(rng, Int32(0):typemax(Int32))
+    _, params = l.initfn(jax.random.PRNGKey(seed), l.input_shape)
+    return fmap(x -> DLPack.wrap(x, pyto_dlpack), params)
+end
+# Forward pass: calls the wrapped apply function as applyfn(ps, x); `st` is returned unchanged.
+function (model::LuxStaxWrapper)(x, ps, st::NamedTuple)
+    return model.applyfn(ps, x), st
+end
\ No newline at end of file