From 099c8d66e9f557caebef9f125732df71e27e11dc Mon Sep 17 00:00:00 2001 From: nicoladicicco <93935338+nicoladicicco@users.noreply.github.com> Date: Tue, 6 Aug 2024 07:32:18 +0000 Subject: [PATCH 01/12] add tests --- Manifest.toml | 10 ++++++++++ Project.toml | 11 ++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..98cd839 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,10 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.4" +manifest_format = "2.0" +project_hash = "6f13a7c5afeb93446b0958337560825a3e93b4da" + +[[deps.TestItems]] +git-tree-sha1 = "42fd9023fef18b9b78c8343a4e2f3813ffbcefcb" +uuid = "1c621080-faea-4a02-84b6-bbd5e436b8fe" +version = "1.0.0" diff --git a/Project.toml b/Project.toml index e727b60..08c28ed 100644 --- a/Project.toml +++ b/Project.toml @@ -3,13 +3,22 @@ uuid = "314c63f5-3dda-4b35-95e7-4cc933f13053" authors = ["Jean-François BAFFIER (@Azzaare)"] version = "0.0.1" +[deps] +TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" + [compat] +Aqua = "0.8" +JET = "0.9" julia = "1.10" +Test = "1.10" +TestItems = "1.0.0" +TestItemRunner = "1.0.0" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" [targets] -test = ["Aqua", "JET", "Test"] +test = ["Aqua", "JET", "Test", "TestItemRunner"] From eca13d30d8d5a5fb4984795e63229c3cea84e285 Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Mon, 2 Sep 2024 14:29:57 +0900 Subject: [PATCH 02/12] Tentative first interface --- Manifest.toml | 645 ++++++++++++++++++++++++++++++++++++++++- Project.toml | 14 +- README.md | 3 + src/llm.jl | 170 +++++++++++ src/prompt.jl | 6 + test/Aqua.jl | 3 + test/JET.jl | 3 + test/TestItemRunner.jl | 3 + test/runtests.jl | 
17 +- test/utils.jl | 33 +++ 10 files changed, 881 insertions(+), 16 deletions(-) create mode 100644 src/llm.jl create mode 100644 src/prompt.jl create mode 100644 test/Aqua.jl create mode 100644 test/JET.jl create mode 100644 test/TestItemRunner.jl create mode 100644 test/utils.jl diff --git a/Manifest.toml b/Manifest.toml index 98cd839..77883a6 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,10 +1,651 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.4" +julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "6f13a7c5afeb93446b0958337560825a3e93b4da" +project_hash = "2ab469f27c9700c80fa8dc8e1198e45829d4137c" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.BitFlags]] +git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.9" + +[[deps.CSTParser]] +deps = ["Tokenize"] +git-tree-sha1 = "0157e592151e39fa570645e2b2debcdfb8a0f112" +uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" +version = "3.4.3" + +[[deps.CodeTracking]] +deps = ["InteractiveUtils", "UUIDs"] +git-tree-sha1 = "7eee164f122511d3e4e1ebadb7956939ea7e1c77" +uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" +version = "1.3.6" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.6" + +[[deps.CommonMark]] +deps = ["Crayons", "JSON", "PrecompileTools", "URIs"] +git-tree-sha1 = "532c4185d3c9037c0237546d817858b23cf9e071" +uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6" +version = "0.8.12" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.16.0" +weakdeps = ["Dates", 
"LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.CompositionalNetworks]] +deps = ["ConstraintCommons", "ConstraintDomains", "Dictionaries", "Distances", "JuliaFormatter", "OrderedCollections", "Random", "TestItems", "Unrolled"] +git-tree-sha1 = "42ea78627a970cc0f4d0707fb87c29a5892a65cc" +uuid = "4b67e4b5-442d-4ef5-b760-3f5df3a57537" +version = "0.5.9" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.4.2" + +[[deps.ConstraintCommons]] +deps = ["Dictionaries", "TestItems"] +git-tree-sha1 = "779227189854f846de5f72b518e50dda14c7886b" +uuid = "e37357d9-0691-492f-a822-e5ea6a920954" +version = "0.2.3" + +[[deps.ConstraintDomains]] +deps = ["ConstraintCommons", "Intervals", "PatternFolds", "StatsBase", "TestItems"] +git-tree-sha1 = "02380c829c947c0579864c51affa1646a170d037" +uuid = "5800fd60-8556-4464-8d61-84ebf7a0bedb" +version = "0.3.13" + +[[deps.Constraints]] +deps = ["CompositionalNetworks", "ConstraintCommons", "ConstraintDomains", "DataFrames", "Dictionaries", "MacroTools", "PrettyTables", "TestItems"] +path = "../Constraints" +uuid = "30f324ab-b02d-43f0-b619-e131c61659f7" +version = "0.5.6" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", 
"Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.Dictionaries]] +deps = ["Indexing", "Random", "Serialization"] +git-tree-sha1 = "35b66b6744b2d92c778afd3a88d2571875664a2a" +uuid = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" +version = "0.4.2" + +[[deps.Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.11" + + [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" + DistancesSparseArraysExt = "SparseArrays" + + [deps.Distances.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.10" + 
+[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.8" + +[[deps.Indexing]] +git-tree-sha1 = "ce1566720fd6b19ff3411404d4b977acd4814f9f" +uuid = "313cdc1a-70c2-5d6a-ae34-0150d3930a38" +version = "1.1.1" + +[[deps.InlineStrings]] +git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.2" + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + ParsersExt = "Parsers" + + [deps.InlineStrings.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.Intervals]] +deps = ["Dates", "Printf", "RecipesBase", "Serialization", "TimeZones"] +git-tree-sha1 = "ac0aaa807ed5eaf13f67afe188ebc07e828ff640" +uuid = "d8418881-c3e1-53bb-8760-2df7ec849ed5" +version = "1.10.0" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + 
+[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.6.0" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.0" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.JuliaFormatter]] +deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "PrecompileTools", "TOML", "Tokenize"] +git-tree-sha1 = "bb4696471330275adfd6c78c6173f276e8c067aa" +uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +version = "1.0.60" + +[[deps.JuliaInterpreter]] +deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] +git-tree-sha1 = "4b415b6cccb9ab61fec78a621572c82ac7fa5776" +uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" +version = "0.9.35" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + +[[deps.Lazy]] +deps = ["MacroTools"] +git-tree-sha1 = "1370f8202dac30758f3c345f9909b97f53d87d3f" +uuid = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" +version = "0.15.1" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" 
+version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.28" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.3" + +[[deps.LoweredCodeUtils]] +deps = ["JuliaInterpreter"] +git-tree-sha1 = "1ce1834f9644a8f7c011eb0592b7fd6c42c90653" +uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" +version = "3.0.1" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = 
"d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.9" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.Mocking]] +deps = ["Compat", "ExprTools"] +git-tree-sha1 = "2c140d60d7cb82badf06d8783800d0bcd1a7daa2" +uuid = "78c3b35d-d492-501b-9361-3d52fe80e533" +version = "0.8.1" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.3" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.14+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.PatternFolds]] +deps = ["Intervals", "Lazy", "Random", 
"Reexport", "TestItemRunner", "TestItems"] +git-tree-sha1 = "21fb4c221aca131474a886a015a3cd5b1a42b6d2" +uuid = "c18a7f1d-76ad-4ce4-950d-5419b888513b" +version = "0.2.5" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.2" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = 
"ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.Revise]] +deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "REPL", "Requires", "UUIDs", "Unicode"] +git-tree-sha1 = "7b7850bb94f75762d567834d7e9802fc22d62f9c" +uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" +version = "3.5.18" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.5" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = 
"5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.11.0" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TZJData]] +deps = ["Artifacts"] +git-tree-sha1 = "1607ad46cf8d642aa779a1d45af1c8620dbf6915" +uuid = "dc5dba14-91b3-4cab-a142-028a31da12f7" +version = "1.2.0+2024a" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.12.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TestItemRunner]] +deps = ["Pkg", "TOML", "Test", "TestItems", "UUIDs"] +git-tree-sha1 = "29647c5398be04a1d697265ba385bdf3f623c993" +uuid = "f8b46487-2199-4994-9208-9a1283c18c0a" +version = "1.0.5" [[deps.TestItems]] git-tree-sha1 = "42fd9023fef18b9b78c8343a4e2f3813ffbcefcb" uuid = "1c621080-faea-4a02-84b6-bbd5e436b8fe" version = "1.0.0" + +[[deps.TimeZones]] +deps = ["Dates", "Downloads", 
"InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"] +git-tree-sha1 = "b92aebdd3555f3a7e3267cf17702033c2814ef48" +uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" +version = "1.18.0" +weakdeps = ["RecipesBase"] + + [deps.TimeZones.extensions] + TimeZonesRecipesBaseExt = "RecipesBase" + +[[deps.Tokenize]] +git-tree-sha1 = "468b4685af4abe0e9fd4d7bf495a6554a6276e75" +uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" +version = "0.5.29" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.2" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.11.0+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index 08c28ed..251c63e 100644 --- a/Project.toml +++ b/Project.toml @@ -4,15 +4,19 @@ authors = ["Jean-François BAFFIER (@Azzaare)"] version = "0.0.1" [deps] +Constraints = "30f324ab-b02d-43f0-b619-e131c61659f7" +HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" TestItems = 
"1c621080-faea-4a02-84b6-bbd5e436b8fe" [compat] -Aqua = "0.8" -JET = "0.9" +Aqua = "0" +JET = "0" +Test = "1" +TestItemRunner = "1" +TestItems = "1" julia = "1.10" -Test = "1.10" -TestItems = "1.0.0" -TestItemRunner = "1.0.0" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/README.md b/README.md index 2a0365b..1d8faea 100644 --- a/README.md +++ b/README.md @@ -3,3 +3,6 @@ [![Build Status](https://github.com/Azzaare/ConstraintsTranslator.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/Azzaare/ConstraintsTranslator.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/Azzaare/ConstraintsTranslator.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/Azzaare/ConstraintsTranslator.jl) [![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) + +A package for translating natural language into Constraint Programming models for `LocalSearchSolvers.jl`. + diff --git a/src/llm.jl b/src/llm.jl new file mode 100644 index 0000000..3bab7e8 --- /dev/null +++ b/src/llm.jl @@ -0,0 +1,170 @@ +const GROQ_URL::String = "https://api.groq.com/openai/v1/chat/completions" +const GEMINI_URL::String = "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}:generateContent" +const GEMINI_URL_STREAM::String = "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}:streamGenerateContent?alt=sse" + +abstract type AbstractLLM end + +""" + GroqLLM +Structure encapsulating the parameters for accessing the Groq LLM API. +- `api_key`: an API key for accessing the Groq API (https://groq.com), read from the environmental variable GROQ_API_KEY +- `model_id`: a string identifier for the model to query. See https://console.groq.com/docs/models for the list of available models. 
+""" +struct GroqLLM <: AbstractLLM + api_key::String + model_id::String + + function GroqLLM(model_id::String) + api_key = get(ENV, "GROQ_API_KEY", "") + if isempty(api_key) + error("Environment variable GROQ_API_KEY is not set") + end + new(api_key, model_id) + end +end + +""" + Google LLM +Structure encapsulating the parameters for accessing the Google LLM API. +- `api_key`: an API key for accessing the Google Gemini API (https://ai.google.dev/gemini-api/docs/), read from the environmental variable GOOGLE_API_KEY +- `model_id`: a string identifier for the model to query. See https://ai.google.dev/gemini-api/docs/models/gemini for the list of available models. +""" +struct GoogleLLM <: AbstractLLM + api_key::String + model_id::String + + function GoogleLLM(model_id::String) + api_key = get(ENV, "GOOGLE_API_KEY", "") + if isempty(api_key) + error("Environment variable GOOGLE_API_KEY is not set") + end + new(api_key, model_id) + end +end + +""" + get_completion(llm::GroqLLM, prompt::Prompt) +Returns a completion for the given prompt using the Groq LLM API. +""" +function get_completion(llm::GroqLLM, prompt::Prompt) + headers = [ + "Authorization" => "Bearer $(llm.api_key)", + "Content-Type" => "application/json", + ] + body = JSON3.write(Dict( + "messages" => [ + Dict("role" => "system", "content" => prompt.system), + Dict("role" => "user", "content" => prompt.user), + ], + "model" => llm.model_id, + )) + response = HTTP.post(GROQ_URL, headers, body) + body = JSON3.read(response.body) + return body["choices"][1]["message"]["content"] +end + +""" + get_completion(llm::GoogleLLM, prompt::Prompt) +Returns a completion for the given prompt using the Google Gemini LLM API. 
+""" +function get_completion(llm::GoogleLLM, prompt::Prompt) + url = replace(GEMINI_URL, "{{model_id}}" => llm.model_id) + headers = [ + "x-goog-api-key" => "$(llm.api_key)", + "Content-Type" => "application/json", + ] + body = JSON3.write(Dict( + "contents" => Dict( + "parts" => Dict("text" => prompt.system * prompt.user) + ), + )) + response = HTTP.post(url, headers, body) + body = JSON3.read(response.body) + return body["candidates"][1]["content"]["parts"][1]["text"] +end + +""" + stream_completion(llm::GroqLLM, prompt::Prompt) +Returns a completion for the given prompt using the Groq LLM API. +The completion is streamed as it is generated and printed to the terminal. +""" +function stream_completion(llm::GroqLLM, prompt::Prompt) + headers = [ + "Authorization" => "Bearer $(llm.api_key)", + "Content-Type" => "application/json", + ] + body = JSON3.write(Dict( + "messages" => [ + Dict("role" => "system", "content" => prompt.system), + Dict("role" => "user", "content" => prompt.user), + ], + "model" => llm.model_id, + "stream" => true, + )) + + accumulated_content = "" + event_buffer = "" + + HTTP.open(:POST, GROQ_URL, headers; body = body) do io + write(io, body) + while !eof(io) + chunk = String(readavailable(io)) + events = split(chunk, "\n\n") + if !endswith(event_buffer, "\n\n") + event_buffer = events[end] + events = events[1:(end - 1)] + else + event_buffer = "" + end + events = join(events, "\n") + for line in eachmatch(r"(?<=data: ).*", events, overlap = true) + if line.match == "[DONE]" + print("\n") + break + end + message = JSON3.read(line.match) + if !isempty(message["choices"][1]["delta"]) + print(message["choices"][1]["delta"]["content"]) + accumulated_content *= message["choices"][1]["delta"]["content"] + end + end + end + end + return accumulated_content +end + +""" + stream_completion(llm::GoogleLLM, prompt::Prompt) +Returns a completion for the given prompt using the Google Gemini LLM API. 
+The completion is streamed as it is generated and printed to the terminal. +""" +function stream_completion(llm::GoogleLLM, prompt::Prompt) + url = replace(GEMINI_URL_STREAM, "{{model_id}}" => llm.model_id) + headers = [ + "x-goog-api-key" => "$(llm.api_key)", + "Content-Type" => "application/json", + ] + body = JSON3.write(Dict( + "contents" => Dict( + "parts" => Dict("text" => prompt.system * prompt.user) + ), + )) + + accumulated_content = "" + + HTTP.open(:POST, url, headers; body = body) do io + write(io, body) + while !eof(io) + chunk = String(readavailable(io)) + line = match(r"(?<=data: ).*", chunk) + if isnothing(line) + print("\n") + break + end + message = JSON3.read(line.match) + print(message["candidates"][1]["content"]["parts"][1]["text"]) + accumulated_content *= String(message["candidates"][1]["content"]["parts"][1]["text"]) + end + end + return accumulated_content +end \ No newline at end of file diff --git a/src/prompt.jl b/src/prompt.jl new file mode 100644 index 0000000..e15dd92 --- /dev/null +++ b/src/prompt.jl @@ -0,0 +1,6 @@ +abstract type AbstractPrompt end + +struct Prompt <: AbstractPrompt + system::String + user::String +end diff --git a/test/Aqua.jl b/test/Aqua.jl new file mode 100644 index 0000000..d60f23a --- /dev/null +++ b/test/Aqua.jl @@ -0,0 +1,3 @@ +@testset "Code quality (Aqua.jl)" begin + Aqua.test_all(ConstraintsTranslator) +end \ No newline at end of file diff --git a/test/JET.jl b/test/JET.jl new file mode 100644 index 0000000..a030a2a --- /dev/null +++ b/test/JET.jl @@ -0,0 +1,3 @@ +@testset "Code linting (JET.jl)" begin + JET.test_package(ConstraintsTranslator; target_defined_modules = true) +end \ No newline at end of file diff --git a/test/TestItemRunner.jl b/test/TestItemRunner.jl new file mode 100644 index 0000000..4cda7e5 --- /dev/null +++ b/test/TestItemRunner.jl @@ -0,0 +1,3 @@ +@testset "TestItemRunner" begin + @run_package_tests +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl 
index bcd6fee..b735084 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,13 @@ -using ConstraintsTranslator -using Test using Aqua +using ConstraintsTranslator using JET +using Test +using TestItemRunner +using TestItems @testset "ConstraintsTranslator.jl" begin - @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(ConstraintsTranslator) - end - @testset "Code linting (JET.jl)" begin - JET.test_package(ConstraintsTranslator; target_defined_modules = true) - end - # Write your tests here. + include("Aqua.jl") + include("JET.jl") + include("TestItemRunner.jl") + include("utils.jl") end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..fed83ef --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,33 @@ +@testset "parse_code tests" begin + input = """ + Here is some text. + ```python + code block 1 + ``` + More text. + ```julia + code block 2 + ``` + Even more text. + ``` + code block 3 + ``` + """ + + expected_output = Dict( + "python" => "code block 1", + "julia" => "code block 2", + "plain" => "code block 3", + ) + + result = parse_code(input) + + @test haskey(result, "python") + @test strip(result["python"]) == strip(expected_output["python"]) + + @test haskey(result, "julia") + @test strip(result["julia"]) == strip(expected_output["julia"]) + + @test haskey(result, "plain") + @test strip(result["plain"]) == strip(expected_output["plain"]) +end \ No newline at end of file From 53b7f95ce16f60717b6d9d4ab5885420351334d5 Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Mon, 2 Sep 2024 14:30:52 +0900 Subject: [PATCH 03/12] First evaluation dataset --- dataset/prompts/abstract_knapsack.txt | 16 ++++++ dataset/prompts/calendar_scheduling.txt | 17 ++++++ dataset/prompts/cargo_loading_2d.txt | 13 +++++ dataset/prompts/cargo_loading_3d.txt | 13 +++++ dataset/prompts/constrained_shortest_path.txt | 20 +++++++ dataset/prompts/cutting_stock.txt | 12 ++++ 
dataset/prompts/frequency_assignment.txt | 33 +++++++++++ dataset/prompts/golomb.txt | 8 +++ dataset/prompts/job_shop_scheduling.txt | 19 +++++++ dataset/prompts/knapsack.txt | 16 ++++++ dataset/prompts/marriage_seats.txt | 23 ++++++++ dataset/prompts/n_queens.txt | 2 + dataset/prompts/nurse_rostering.txt | 56 +++++++++++++++++++ dataset/prompts/sudoku.txt | 16 ++++++ dataset/prompts/traveling_salesman.txt | 24 ++++++++ dataset/prompts/university_timetabling.txt | 40 +++++++++++++ dataset/prompts/vehicle_routing.txt | 29 ++++++++++ .../prompts/vehicle_routing_time_windows.txt | 30 ++++++++++ 18 files changed, 387 insertions(+) create mode 100644 dataset/prompts/abstract_knapsack.txt create mode 100644 dataset/prompts/calendar_scheduling.txt create mode 100644 dataset/prompts/cargo_loading_2d.txt create mode 100644 dataset/prompts/cargo_loading_3d.txt create mode 100644 dataset/prompts/constrained_shortest_path.txt create mode 100644 dataset/prompts/cutting_stock.txt create mode 100644 dataset/prompts/frequency_assignment.txt create mode 100644 dataset/prompts/golomb.txt create mode 100644 dataset/prompts/job_shop_scheduling.txt create mode 100644 dataset/prompts/knapsack.txt create mode 100644 dataset/prompts/marriage_seats.txt create mode 100644 dataset/prompts/n_queens.txt create mode 100644 dataset/prompts/nurse_rostering.txt create mode 100644 dataset/prompts/sudoku.txt create mode 100644 dataset/prompts/traveling_salesman.txt create mode 100644 dataset/prompts/university_timetabling.txt create mode 100644 dataset/prompts/vehicle_routing.txt create mode 100644 dataset/prompts/vehicle_routing_time_windows.txt diff --git a/dataset/prompts/abstract_knapsack.txt b/dataset/prompts/abstract_knapsack.txt new file mode 100644 index 0000000..7e3532f --- /dev/null +++ b/dataset/prompts/abstract_knapsack.txt @@ -0,0 +1,16 @@ +I am planning a vacation and need to pack my suitcase, which has a strict weight limit. 
+I have several items to choose from, each with its own weight and level of importance. +The goal is to select the combination of items that will ensure the best possible vacation experience while staying within the allowed weight limit. + +Example input data: +1. items.csv +item_id,item_name,weight,importance +1,ski_combination,7,low +2,warm_clothes,4,normal +3,hiking_boots,3,high +4,hiking_book,1,high +5,umbrella,2,normal + +2. weight.csv +weight_limit +10 \ No newline at end of file diff --git a/dataset/prompts/calendar_scheduling.txt b/dataset/prompts/calendar_scheduling.txt new file mode 100644 index 0000000..20ca905 --- /dev/null +++ b/dataset/prompts/calendar_scheduling.txt @@ -0,0 +1,17 @@ +We are tasked with scheduling a set of meetings within a specific time frame, ensuring that no meetings overlap and all required participants can attend each meeting. +The objective is to find a feasible schedule that accommodates the availability of all participants and fits within the given time constraints. + +Example input data: +1. meetings.csv +meeting_id,duration,participants +1,1,John;Alice +2,2,John;Bob +3,1,Alice;Charlie +4,1,Bob;Charlie + +2. availability.csv +participant_id,availability_start,availability_end +John,09:00,12:00 +Alice,10:00,13:00 +Bob,09:00,11:00 +Charlie,11:00,14:00 \ No newline at end of file diff --git a/dataset/prompts/cargo_loading_2d.txt b/dataset/prompts/cargo_loading_2d.txt new file mode 100644 index 0000000..2382527 --- /dev/null +++ b/dataset/prompts/cargo_loading_2d.txt @@ -0,0 +1,13 @@ +We are working for a logistics company that handles cargo shipping in containers. +Each container has a fixed width of 2.5 meters and a height of 2.5 meters. +The company receives orders to ship various items, each with specific dimensions. +The task is to determine how to load the items into the containers to minimize the number of containers used, while ensuring that no items are rotated and all items fit within the container’s dimensions. 
+ +Example input data: +1. items.csv +item_id,width,height,quantity +I1,1.2,0.5,10 +I2,2.0,0.8,5 +I3,0.5,0.5,20 +I4,1.8,1.5,3 +I5,1.0,1.0,15 \ No newline at end of file diff --git a/dataset/prompts/cargo_loading_3d.txt b/dataset/prompts/cargo_loading_3d.txt new file mode 100644 index 0000000..a86fd5d --- /dev/null +++ b/dataset/prompts/cargo_loading_3d.txt @@ -0,0 +1,13 @@ +We are working for a logistics company that handles cargo shipping in containers. +Each container has fixed dimensions: 2.5 meters in width, 2.5 meters in height, and 6 meters in length. +The company receives orders to ship various items, each with specific dimensions. +We want to determine how to load the items into the containers to minimize the number of containers used, while ensuring that no items are rotated and all items fit within the container's dimensions. + +Example input data: +1. items.csv +item_id,width,height,length,quantity +I1,1.2,0.5,3.0,10 +I2,2.0,0.8,1.5,5 +I3,0.5,0.5,0.5,20 +I4,1.8,1.5,4.0,3 +I5,1.0,1.0,2.0,15 \ No newline at end of file diff --git a/dataset/prompts/constrained_shortest_path.txt b/dataset/prompts/constrained_shortest_path.txt new file mode 100644 index 0000000..11703e3 --- /dev/null +++ b/dataset/prompts/constrained_shortest_path.txt @@ -0,0 +1,20 @@ +We need to find the shortest path, in terms of the number of hops, between a given source and destination in a capacitated graph. +Each link in the graph has a physical length and a capacity. +The objective is to find the path that minimizes the number of hops while satisfying constraints on the path's capacity (which is the minimum edge capacity along the path) and the total path length. + +Example input data: +1. graph.csv +link_id,source_node,destination_node,capacity,length +1,NodeA,NodeB,10,5 +2,NodeB,NodeC,15,7 +3,NodeA,NodeC,8,12 +4,NodeC,NodeD,12,3 +5,NodeB,NodeD,9,4 + +2. source_destination.csv +source_node,destination_node +NodeA,NodeD + +3. 
constraints.csv +min_path_capacity,max_path_length +9,15 diff --git a/dataset/prompts/cutting_stock.txt b/dataset/prompts/cutting_stock.txt new file mode 100644 index 0000000..4ff163b --- /dev/null +++ b/dataset/prompts/cutting_stock.txt @@ -0,0 +1,12 @@ +We have a paper roll manufacturing company that produces standard rolls of paper, all of the same width but with a fixed length of 100 meters. +The company receives orders from customers for smaller rolls of different lengths. +Our task is to determine how to cut the standard rolls to fulfill these orders while minimizing the number of standard rolls used. + +Example input data: +1. orders.csv +order_id,length_required,quantity +O1,30,5 +O2,45,3 +O3,65,2 +O4,50,4 +O5,80,1 \ No newline at end of file diff --git a/dataset/prompts/frequency_assignment.txt b/dataset/prompts/frequency_assignment.txt new file mode 100644 index 0000000..1a19e55 --- /dev/null +++ b/dataset/prompts/frequency_assignment.txt @@ -0,0 +1,33 @@ +We need to assign radio frequencies to a set of transmitters in a telecommunication network. +Each transmitter must be assigned a frequency from a given set of available frequencies. +The objective is to minimize interference between transmitters while using the minimum number of distinct frequencies. +The interference between two transmitters is proportional to the square of their geographical distance and the absolute difference between their assigned frequencies. +Each transmitter must be assigned exactly one frequency. +The frequency assigned to a transmitter must be within its allowed frequency range. +Transmitters that are geographically close to each other must have a minimum frequency separation to avoid interference. +Some transmitters may have pre-assigned frequencies that cannot be changed. Pre_assigned_frequency of -1 means no pre-assignment. + +Example input data: +1. 
transmitters.csv +transmitter_id,x_coordinate,y_coordinate,min_frequency,max_frequency,pre_assigned_frequency +T1,10,20,1,10,-1 +T2,15,25,1,10,-1 +T3,30,40,1,15,-1 +T4,35,45,1,15,7 +T5,50,60,5,20,-1 + +2. available_frequencies.csv +frequency_id,frequency_value +F1,1 +F2,2 +F3,3 +F4,4 +F5,5 + +3. interference_matrix.csv +transmitter1_id,transmitter2_id,min_frequency_separation,interference_cost +T1,T2,2,10 +T1,T3,1,5 +T1,T4,1,3 +T1,T5,0,1 +T2,T3,2,8 \ No newline at end of file diff --git a/dataset/prompts/golomb.txt b/dataset/prompts/golomb.txt new file mode 100644 index 0000000..d64e3f9 --- /dev/null +++ b/dataset/prompts/golomb.txt @@ -0,0 +1,8 @@ +We need to find a feasible Golomb ruler of a specified length m and order n. +A Golomb ruler is a set of n marks placed along a ruler such that all pairwise distances between marks are distinct. +The goal is to determine the positions of the marks that satisfy the distinct distance condition for the given length. + +Example input data: +1. input.csv +ruler_length,number_of_marks +10,4 \ No newline at end of file diff --git a/dataset/prompts/job_shop_scheduling.txt b/dataset/prompts/job_shop_scheduling.txt new file mode 100644 index 0000000..533f080 --- /dev/null +++ b/dataset/prompts/job_shop_scheduling.txt @@ -0,0 +1,19 @@ +We need to schedule a set of jobs on 4 machines, where each job consists of a sequence of tasks. +Each task must be processed on a specific machine for a given duration, and tasks within a job must follow a predefined order (dependency graph). +The objective is to find a feasible schedule that minimizes the overall completion time (makespan) while respecting the task dependencies and machine availability. + +Example input data: +1. 
input.txt +task_id,job_id,machine_id,processing_time,dependencies +1,1,1,3, +2,1,2,2,1 +3,1,3,4,2 +4,2,2,5, +5,2,1,3,4 +6,2,4,2,5 +7,3,3,6, +8,3,2,1,7 +9,3,4,4,8 +10,4,1,2, +11,4,3,3,10 +12,4,4,1,11 \ No newline at end of file diff --git a/dataset/prompts/knapsack.txt b/dataset/prompts/knapsack.txt new file mode 100644 index 0000000..641c151 --- /dev/null +++ b/dataset/prompts/knapsack.txt @@ -0,0 +1,16 @@ +We aim to solve a Knapsack problem where a set of items is given. +For each item, we define a binary decision variable to indicate whether the item is included in the knapsack. +The objective is to maximize the total utility of the selected items without exceeding a given weight limit. + +Example input data: +1. items.csv +item_id,weight,utility +1,2,2 +2,3,3 +3,7,1 +4,4,2 +5,1,3 + +2. weight.csv +weight_limit +10 \ No newline at end of file diff --git a/dataset/prompts/marriage_seats.txt b/dataset/prompts/marriage_seats.txt new file mode 100644 index 0000000..aebbebc --- /dev/null +++ b/dataset/prompts/marriage_seats.txt @@ -0,0 +1,23 @@ +You are tasked with creating a seating arrangement for a wedding reception. +The reception will be held in a venue with round tables, each seating 8 people. +The bride and groom must be seated at the same table (Table 1). +Immediate family members of the bride and groom must be seated at Tables 1 and 2. +Couples must be seated together. +People with known conflicts (e.g., divorced couples, family feuds) must be seated at different tables. +Maximize the number of guests seated with others they know, and seat guests with similar interests together when possible. + +Example input data: +1. guests.csv +guest_id,name,group,dietary_requirement,mobility_issue,interests +1,John Smith,Groom's Family,None,No,Sports +2,Jane Smith,Groom's Family,Vegetarian,No,Art +3,Alice Johnson,Bride's Family,Nut Allergy,Yes,Music +4,Bob Johnson,Bride's Family,None,No,Travel + +2. 
relationships.csv +guest_id1,guest_id2,relationship +1,2,Couple +3,4,Couple +5,6,Conflict +7,8,Strangers +1,6,Friends \ No newline at end of file diff --git a/dataset/prompts/n_queens.txt b/dataset/prompts/n_queens.txt new file mode 100644 index 0000000..b32b4bf --- /dev/null +++ b/dataset/prompts/n_queens.txt @@ -0,0 +1,2 @@ +I want to solve an n-queen puzzle where n is a positive integer. +The n-queen puzzle is the problem of placing n queens on an n x n chessboard such that no two queens can attack each other. diff --git a/dataset/prompts/nurse_rostering.txt b/dataset/prompts/nurse_rostering.txt new file mode 100644 index 0000000..a0ec88c --- /dev/null +++ b/dataset/prompts/nurse_rostering.txt @@ -0,0 +1,56 @@ +We need to create a weekly schedule for nurses in a hospital. +Each nurse must be assigned shifts for a 7-day week, with three shift types: morning (7:00-15:00), evening (15:00-23:00), and night (23:00-7:00). +Each shift must be covered by the required number of nurses with appropriate skills. +Nurses cannot be assigned to more than one shift per day. +Each nurse must have at least 11 hours of rest between shifts. +Each nurse should work between 30 and 40 hours per week. +Each nurse should have at least one weekend day (Saturday or Sunday) off every two weeks. +The number of night shifts for each nurse should be distributed fairly. +The goal is to find a feasible schedule while maximizing fairness in satisfying nurse preferences. + +Example input data: + +1. nurses.csv +nurse_id,name,skill_level,max_shifts_per_week,max_night_shifts_per_week +N1,Alice,senior,5,2 +N2,Bob,junior,5,2 +N3,Charlie,senior,4,1 +N4,Diana,mid,5,2 +N5,Eve,junior,4,1 + +2. 
shift_requirements.csv +day,shift_type,required_seniors,required_mid,required_juniors +Monday,morning,1,1,1 +Monday,evening,1,1,1 +Monday,night,1,0,1 +Tuesday,morning,1,1,1 +Tuesday,evening,1,1,1 +Tuesday,night,1,0,1 +Wednesday,morning,1,1,1 +Wednesday,evening,1,1,1 +Wednesday,night,1,0,1 +Thursday,morning,1,1,1 +Thursday,evening,1,1,1 +Thursday,night,1,0,1 +Friday,morning,1,1,1 +Friday,evening,1,1,1 +Friday,night,1,0,1 +Saturday,morning,1,1,1 +Saturday,evening,1,0,1 +Saturday,night,1,0,1 +Sunday,morning,1,1,1 +Sunday,evening,1,0,1 +Sunday,night,1,0,1 + +3. nurse_preferences.csv +nurse_id,day,shift_type,preference_score +N1,Monday,morning,3 +N1,Monday,evening,1 +N1,Monday,night,0 +N2,Monday,morning,2 +N2,Monday,evening,2 +N2,Monday,night,1 +... +N5,Sunday,morning,3 +N5,Sunday,evening,2 +N5,Sunday,night,0 \ No newline at end of file diff --git a/dataset/prompts/sudoku.txt b/dataset/prompts/sudoku.txt new file mode 100644 index 0000000..23d5426 --- /dev/null +++ b/dataset/prompts/sudoku.txt @@ -0,0 +1,16 @@ +We are tasked with solving a 9x9 Sudoku puzzle. +The puzzle is represented by a 9x9 grid, where each cell contains a number between 1 and 9 or is left blank. +The objective is to fill in the blank cells so that each row, column, and 3x3 subgrid contains all numbers from 1 to 9 exactly once. + +Example input data: +1. 
initial_values.csv +,row1,row2,row3,row4,row5,row6,row7,row8,row9 +col1,5,3,0,0,7,0,0,0,0 +col2,6,0,0,1,9,5,0,0,0 +col3,0,9,8,0,0,0,0,6,0 +col4,8,0,0,0,6,0,0,0,3 +col5,4,0,0,8,0,3,0,0,1 +col6,7,0,0,0,2,0,0,0,6 +col7,0,6,0,0,0,0,2,8,0 +col8,0,0,0,4,1,9,0,0,5 +col9,0,0,0,0,8,0,0,7,9 \ No newline at end of file diff --git a/dataset/prompts/traveling_salesman.txt b/dataset/prompts/traveling_salesman.txt new file mode 100644 index 0000000..3f21ee4 --- /dev/null +++ b/dataset/prompts/traveling_salesman.txt @@ -0,0 +1,24 @@ +We need to determine the shortest possible route for a salesman who must visit a set of cities exactly once and return to the starting city. +The objective is to minimize the total travel distance while ensuring that each city is visited exactly once. + +Example input data: +1. cities.csv +city_id,city_name +1,CityA +2,CityB +3,CityC +4,CityD +5,CityE + +2. distances.csv +from,to,distance +CityA,CityB,10 +CityA,CityC,8 +CityA,CityD,15 +CityA,CityE,12 +CityB,CityC,5 +CityB,CityD,9 +CityB,CityE,7 +CityC,CityD,6 +CityC,CityE,4 +CityD,CityE,3 diff --git a/dataset/prompts/university_timetabling.txt b/dataset/prompts/university_timetabling.txt new file mode 100644 index 0000000..980159b --- /dev/null +++ b/dataset/prompts/university_timetabling.txt @@ -0,0 +1,40 @@ +We need to create a weekly timetable for university courses. +Each course must be assigned specific time slots and a room for a 5-day week (Monday to Friday), with time slots from 9:00 to 17:00 in 1-hour increments. +The timetable must satisfy the following rules: +1. No teacher can teach more than one course simultaneously. +2. No student group can attend more than one course simultaneously. +3. Each room can host only one course at a time. +4. The assigned room must have sufficient capacity for the course's enrolled students. +5. Courses requiring special equipment must be assigned to rooms with that equipment. +6. Each course must be scheduled for its required number of hours per week. +7. 
Courses must be scheduled within the defined working hours (9:00 to 17:00, Monday to Friday). +The goal is to find a feasible schedule that satisfies all these constraints while minimizing the number of idle periods for both teachers and student groups. + +Example input data: + +1. courses.csv +course_id,name,teacher_id,student_group,enrolled_students,hours_per_week,required_equipment +CS101,Intro to Programming,T1,G1,30,3,computers +MATH201,Calculus,T2,G1,25,4,none +PHYS101,Physics I,T3,G2,20,3,lab +ENG201,Literature,T4,G2,15,2,none + +2. teachers.csv +teacher_id,name,max_hours_per_day +T1,Dr. Smith,6 +T2,Prof. Johnson,4 +T3,Dr. Brown,5 +T4,Dr. Davis,3 + +3. rooms.csv +room_id,capacity,equipment +R101,35,none +R102,30,computers +R103,25,lab +R104,40,none + +4. time_slots.csv +slot_id,day,start_time,end_time +1,Monday,09:00,10:00 +2,Monday,10:00,11:00 +3,Monday,11:00,12:00 diff --git a/dataset/prompts/vehicle_routing.txt b/dataset/prompts/vehicle_routing.txt new file mode 100644 index 0000000..a2bce55 --- /dev/null +++ b/dataset/prompts/vehicle_routing.txt @@ -0,0 +1,29 @@ +We need to determine the optimal set of routes for a fleet of vehicles to deliver goods to a set of customers. +Each route must start and end at the depot. +The objective is to minimize the total travel distance while ensuring all deliveries are made without exceeding vehicle capacities. + +Example input data: +1. vehicles.csv +vehicle_id,capacity +1,15 +2,15 + +2. customers.csv +customer_id,location,demand +1,LocA,5 +2,LocB,7 +3,LocC,4 +4,LocD,6 + +3. 
distances.csv +from,to,distance +Depot,LocA,5 +Depot,LocB,10 +Depot,LocC,7 +Depot,LocD,3 +LocA,LocB,6 +LocA,LocC,8 +LocA,LocD,4 +LocB,LocC,2 +LocB,LocD,5 +LocC,LocD,3 \ No newline at end of file diff --git a/dataset/prompts/vehicle_routing_time_windows.txt b/dataset/prompts/vehicle_routing_time_windows.txt new file mode 100644 index 0000000..f0c9348 --- /dev/null +++ b/dataset/prompts/vehicle_routing_time_windows.txt @@ -0,0 +1,30 @@ +We need to determine the optimal set of routes for a fleet of vehicles to deliver goods to a set of customers. +Each route must start and end at the depot. +Each customer has a specific time window during which the delivery must occur, and each vehicle has a maximum capacity. +The objective is to minimize the total travel distance while ensuring all deliveries are made within the specified time windows and without exceeding vehicle capacities. + +Example input data: +1. vehicles.csv +vehicle_id,capacity +1,15 +2,15 + +2. customers.csv +customer_id,location,demand,time_window_start,time_window_end +1,LocA,5,08:00,10:00 +2,LocB,7,09:00,11:00 +3,LocC,4,10:00,12:00 +4,LocD,6,07:00,09:00 + +3. 
distances.csv +from,to,distance +Depot,LocA,5 +Depot,LocB,10 +Depot,LocC,7 +Depot,LocD,3 +LocA,LocB,6 +LocA,LocC,8 +LocA,LocD,4 +LocB,LocC,2 +LocB,LocD,5 +LocC,LocD,3 \ No newline at end of file From ad5d53b5c51b9c123a6952af8d7c04dae57ed750 Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Tue, 3 Sep 2024 17:09:22 +0900 Subject: [PATCH 04/12] Basic templating and prompt structure --- src/prompt.jl | 6 ++ src/template.jl | 159 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 src/template.jl diff --git a/src/prompt.jl b/src/prompt.jl index e15dd92..f3062ec 100644 --- a/src/prompt.jl +++ b/src/prompt.jl @@ -1,5 +1,11 @@ abstract type AbstractPrompt end +""" + Prompt +Simple data structure encapsulating a system prompt and a user prompt for LLM generation. +- `system`: the system prompt. +- `user`: the user prompt. +""" struct Prompt <: AbstractPrompt system::String user::String diff --git a/src/template.jl b/src/template.jl new file mode 100644 index 0000000..01f709a --- /dev/null +++ b/src/template.jl @@ -0,0 +1,159 @@ +abstract type AbstractMetadata end +abstract type AbstractTemplate end + +""" + MetadataMessage + +Represents the metadata information of a prompt template. +The templates follow the specifications of `PromptingTools.jl`. + +# Fields +- `content::String`: The content of the metadata message. +- `description::String`: A description of the metadata message. +- `version::String`: The version of the metadata message. +""" +struct MetadataMessage <: AbstractMessage + content::String + description::String + version::String +end + +""" +Represents the prompt template of a system message. +The template can optionally contain string placeholders enclosed in double curly braces, e.g., `{{variable}}`. +Placeholders must be replaced with actual values when generating prompts. + +# Fields +- `content::String`: The content of the system message. 
+- `variables::Vector{String}`: A list of variables used in the system message. +""" +struct SystemMessage <: AbstractMessage + content::String + variables::Vector{String} +end + +""" +Represents the prompt template of a user message. +The template can optionally contain string placeholders enclosed in double curly braces, e.g., `{{variable}}`. +Placeholders must be replaced with actual values when generating prompts. + +# Fields +- `content::String`: The content of the system message. +- `variables::Vector{String}`: A list of variables used in the system message. +""" +struct UserMessage <: AbstractMessage + content::String + variables::Vector{String} +end + +""" +Represents a complete prompt template, comprising metadata, system, and user messages. + +# Fields +- `metadata::MetadataMessage`: The metadata message of the prompt template. +- `system::SystemMessage`: The system message of the prompt template. +- `user::UserMessage`: The user message of the prompt template. +""" +struct PromptTemplate <: AbstractTemplate + metadata::MetadataMessage + system::SystemMessage + user::UserMessage +end + +""" + read_template(data_path::String) + +Reads a prompt template from a JSON file specified by `data_path`. +The function parses the JSON data and constructs a `PromptTemplate` object containing metadata, system, and user messages. +TODO: validate the JSON data against a schema to ensure it is valid before parsing. + +# Arguments +- `data_path::String`: The path to the JSON file containing the prompt template. + +# Returns +- `PromptTemplate`: A `PromptTemplate` structure encapsulating the metadata, system, and user messages. 
+""" +function read_template(data_path::String)::PromptTemplate + file_content = read(data_path, String) + data = JSON.parse(file_content) + + metadata = nothing + system = nothing + user = nothing + + for item in data + _type = item["_type"] + if _type == "metadatamessage" + metadata = MetadataMessage( + item["content"], + item["description"], + item["version"], + ) + elseif _type == "systemmessage" + system = SystemMessage( + item["content"], + item["variables"], + ) + elseif _type == "usermessage" + user = UserMessage( + item["content"], + item["variables"], + ) + else + error("Unknown message type: $_type") + end + end + + if isnothing(metadata) || isnothing(system) || isnothing(user) + error("Template must contain metadata, system, and user messages") + end + + return PromptTemplate(metadata, system, user) +end + +""" + format_template(template::PromptTemplate; kwargs...)::Prompt + +Formats a `PromptTemplate` by substituting all variables in the system and user messages with user-provided values. + +# Arguments +- `template::PromptTemplate`: The prompt template containing metadata, system, and user messages. +- `kwargs...`: A variable number of keyword arguments where keys are variable names and values are the corresponding replacements. + +# Returns +- `Prompt`: A `Prompt` struct with the system and user messages containing the substituted values. + +# Raises +- `ArgumentError`: If any variables specified in the system or user templates are not present in the `kwargs`. +- `Warning`: If there are extra `kwargs` that are not used in the templates. +""" +function format_template(template::PromptTemplate; kwargs...) 
+ system = template.system.content + user = template.user.content + + template_vars = union(Set(template.system.variables), Set(template.user.variables)) + kwargs_keys = Set(keys(kwargs)) + + # Check for missing variables in kwargs + missing_vars = template_vars - kwargs_keys + if !isempty(missing_vars) + error("Missing variables in kwargs: $(collect(missing_vars))") + end + + # Check for extra kwargs + extra_kwargs = kwargs_keys - template_vars + if !isempty(extra_kwargs) + @warn "Extra kwargs will be ignored: $(collect(extra_kwargs))" + end + + # Substitute variables in the system and user content + for var in template.system.variables + system = replace(system, "{{$(var)}}", kwargs[var]) + end + + for var in template.user.variables + user = replace(user, "{{$(var)}}", kwargs[var]) + end + + return Prompt(system, user) +end \ No newline at end of file From f183ffe0fe43ee224b6df1ba6ea797dc82c1e089 Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Tue, 3 Sep 2024 18:03:11 +0900 Subject: [PATCH 05/12] Add basic parsing, fixes --- src/ConstraintsTranslator.jl | 23 ++++++++++++++++++----- src/parsing.jl | 28 ++++++++++++++++++++++++++++ src/prompt.jl | 2 ++ src/template.jl | 7 +++++-- 4 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 src/parsing.jl diff --git a/src/ConstraintsTranslator.jl b/src/ConstraintsTranslator.jl index 82f2581..27e54b6 100644 --- a/src/ConstraintsTranslator.jl +++ b/src/ConstraintsTranslator.jl @@ -1,12 +1,25 @@ module ConstraintsTranslator -#SECTION - Imports +# Imports +import HTTP +import JSON3 import TestItems: @testitem +import Constraints: USUAL_CONSTRAINTS -#SECTION - Exports +# Exports +export parse_code +export translate +export Prompt +export PromptTemplate +export GroqLLM +export GoogleLLM +export get_completion +export stream_completion -#SECTION - Includes - -#SECTION - Main function (optional) +# Includes +include("prompt.jl") +include("template.jl") 
+include("llm.jl") +include("utils.jl") end diff --git a/src/parsing.jl b/src/parsing.jl new file mode 100644 index 0000000..ae821cc --- /dev/null +++ b/src/parsing.jl @@ -0,0 +1,28 @@ +""" + parse_code(s::String) +Parse the code blocks in the input string `s` delimited by triple backticks and an optional language annotation. +Returns a dictionary keyed by language. Code blocks from the same language are concatenated. +""" +function parse_code(s::String) + # Regular expression to match code blocks with optional language annotation + pattern = r"```(\w*)\n(.*?)```"s + + # Find all matches + matches = eachmatch(pattern, s) + + # Initialize a dictionary to store code blocks by language + code_dict = Dict{String, String}() + + # Extract the code blocks and their language annotations + for m in matches + lang = m.captures[1] == "" ? "plain" : m.captures[1] + code = strip(m.captures[2]) + if haskey(code_dict, lang) + code_dict[lang] *= "\n" * code + else + code_dict[lang] = code + end + end + + return code_dict +end \ No newline at end of file diff --git a/src/prompt.jl b/src/prompt.jl index f3062ec..1357381 100644 --- a/src/prompt.jl +++ b/src/prompt.jl @@ -3,6 +3,8 @@ abstract type AbstractPrompt end """ Prompt Simple data structure encapsulating a system prompt and a user prompt for LLM generation. + +## Fields - `system`: the system prompt. - `user`: the user prompt. """ diff --git a/src/template.jl b/src/template.jl index 01f709a..327f74c 100644 --- a/src/template.jl +++ b/src/template.jl @@ -1,4 +1,4 @@ -abstract type AbstractMetadata end +abstract type AbstractMessage end abstract type AbstractTemplate end """ @@ -72,6 +72,9 @@ TODO: validate the JSON data against a schema to ensure it is valid before parsi # Returns - `PromptTemplate`: A `PromptTemplate` structure encapsulating the metadata, system, and user messages. + +# Raises +- `ErrorException`: if the template does not match the specification format. 
""" function read_template(data_path::String)::PromptTemplate file_content = read(data_path, String) @@ -124,7 +127,7 @@ Formats a `PromptTemplate` by substituting all variables in the system and user - `Prompt`: A `Prompt` struct with the system and user messages containing the substituted values. # Raises -- `ArgumentError`: If any variables specified in the system or user templates are not present in the `kwargs`. +- `ErrorException`: If any variables specified in the system or user templates are not present in the `kwargs`. - `Warning`: If there are extra `kwargs` that are not used in the templates. """ function format_template(template::PromptTemplate; kwargs...) From ff2df2acf5ea0ee082e9688ec3ddbe92fc3ccab8 Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:52:49 +0900 Subject: [PATCH 06/12] Update dataset --- .../prompts/capacitated_facility_location.txt | 15 +++++++++ dataset/prompts/cargo_loading_2d.txt | 6 +--- dataset/prompts/cargo_loading_3d.txt | 4 --- dataset/prompts/job_shop_scheduling.txt | 7 ---- dataset/prompts/marriage_seats.txt | 2 +- dataset/prompts/min_cut.txt | 0 dataset/prompts/nurse_rostering.txt | 21 +----------- dataset/prompts/quadratic_assignment.txt | 32 +++++++++++++++++++ dataset/prompts/set_cover.txt | 18 +++++++++++ dataset/prompts/traveling_salesman.txt | 8 ----- 10 files changed, 68 insertions(+), 45 deletions(-) create mode 100644 dataset/prompts/capacitated_facility_location.txt create mode 100644 dataset/prompts/min_cut.txt create mode 100644 dataset/prompts/quadratic_assignment.txt create mode 100644 dataset/prompts/set_cover.txt diff --git a/dataset/prompts/capacitated_facility_location.txt b/dataset/prompts/capacitated_facility_location.txt new file mode 100644 index 0000000..e62e94a --- /dev/null +++ b/dataset/prompts/capacitated_facility_location.txt @@ -0,0 +1,15 @@ +We aim to solve a Capacitated Facility Location problem where a set of facilities and 
customers is given. +The objective is to minimize the total cost of opening facilities and serving customers while ensuring that each customer's demand is fully satisfied, and no facility exceeds its capacity. + +Example input data: +1. facilities.csv +facility_id,opening_cost,capacity +1,5,15 + +2. customers.csv +customer_id,demand +1,5 + +3. transport_cost.csv +facility_id,customer_id,cost +1,1,3 diff --git a/dataset/prompts/cargo_loading_2d.txt b/dataset/prompts/cargo_loading_2d.txt index 2382527..404f454 100644 --- a/dataset/prompts/cargo_loading_2d.txt +++ b/dataset/prompts/cargo_loading_2d.txt @@ -6,8 +6,4 @@ The task is to determine how to load the items into the containers to minimize t Example input data: 1. items.csv item_id,width,height,quantity -I1,1.2,0.5,10 -I2,2.0,0.8,5 -I3,0.5,0.5,20 -I4,1.8,1.5,3 -I5,1.0,1.0,15 \ No newline at end of file +I1,1.2,0.5,10 \ No newline at end of file diff --git a/dataset/prompts/cargo_loading_3d.txt b/dataset/prompts/cargo_loading_3d.txt index a86fd5d..2a5470b 100644 --- a/dataset/prompts/cargo_loading_3d.txt +++ b/dataset/prompts/cargo_loading_3d.txt @@ -7,7 +7,3 @@ Example input data: 1. 
items.csv item_id,width,height,length,quantity I1,1.2,0.5,3.0,10 -I2,2.0,0.8,1.5,5 -I3,0.5,0.5,0.5,20 -I4,1.8,1.5,4.0,3 -I5,1.0,1.0,2.0,15 \ No newline at end of file diff --git a/dataset/prompts/job_shop_scheduling.txt b/dataset/prompts/job_shop_scheduling.txt index 533f080..994183d 100644 --- a/dataset/prompts/job_shop_scheduling.txt +++ b/dataset/prompts/job_shop_scheduling.txt @@ -10,10 +10,3 @@ task_id,job_id,machine_id,processing_time,dependencies 3,1,3,4,2 4,2,2,5, 5,2,1,3,4 -6,2,4,2,5 -7,3,3,6, -8,3,2,1,7 -9,3,4,4,8 -10,4,1,2, -11,4,3,3,10 -12,4,4,1,11 \ No newline at end of file diff --git a/dataset/prompts/marriage_seats.txt b/dataset/prompts/marriage_seats.txt index aebbebc..eebb660 100644 --- a/dataset/prompts/marriage_seats.txt +++ b/dataset/prompts/marriage_seats.txt @@ -8,7 +8,7 @@ Maximize the number of guests seated with others they know, and seat guests with Example input data: 1. guests.csv -guest_id,name,group,dietary_requirement,mobility_issue,interests +guest_id,name,group,dietary_requirement,interests 1,John Smith,Groom's Family,None,No,Sports 2,Jane Smith,Groom's Family,Vegetarian,No,Art 3,Alice Johnson,Bride's Family,Nut Allergy,Yes,Music diff --git a/dataset/prompts/min_cut.txt b/dataset/prompts/min_cut.txt new file mode 100644 index 0000000..e69de29 diff --git a/dataset/prompts/nurse_rostering.txt b/dataset/prompts/nurse_rostering.txt index a0ec88c..b20b163 100644 --- a/dataset/prompts/nurse_rostering.txt +++ b/dataset/prompts/nurse_rostering.txt @@ -26,21 +26,6 @@ Monday,night,1,0,1 Tuesday,morning,1,1,1 Tuesday,evening,1,1,1 Tuesday,night,1,0,1 -Wednesday,morning,1,1,1 -Wednesday,evening,1,1,1 -Wednesday,night,1,0,1 -Thursday,morning,1,1,1 -Thursday,evening,1,1,1 -Thursday,night,1,0,1 -Friday,morning,1,1,1 -Friday,evening,1,1,1 -Friday,night,1,0,1 -Saturday,morning,1,1,1 -Saturday,evening,1,0,1 -Saturday,night,1,0,1 -Sunday,morning,1,1,1 -Sunday,evening,1,0,1 -Sunday,night,1,0,1 3. 
nurse_preferences.csv nurse_id,day,shift_type,preference_score @@ -49,8 +34,4 @@ N1,Monday,evening,1 N1,Monday,night,0 N2,Monday,morning,2 N2,Monday,evening,2 -N2,Monday,night,1 -... -N5,Sunday,morning,3 -N5,Sunday,evening,2 -N5,Sunday,night,0 \ No newline at end of file +N2,Monday,night,1 \ No newline at end of file diff --git a/dataset/prompts/quadratic_assignment.txt b/dataset/prompts/quadratic_assignment.txt new file mode 100644 index 0000000..6b5db20 --- /dev/null +++ b/dataset/prompts/quadratic_assignment.txt @@ -0,0 +1,32 @@ +We aim to assign a set of facilities to a set of locations. +The objective is to minimize the total cost, which is equal to the distance between locations time the flow between facilities. +Example input data: + +facilities.csv +facility_id +1 +2 +3 +4 +locations.csv +location_id +1 +2 +3 +4 +flow.csv +facility_id_1,facility_id_2,flow +1,2,10 +1,3,8 +1,4,12 +2,3,6 +2,4,9 +3,4,7 +distance.csv +location_id_1,location_id_2,distance +1,2,4 +1,3,7 +1,4,3 +2,3,5 +2,4,6 +3,4,2 diff --git a/dataset/prompts/set_cover.txt b/dataset/prompts/set_cover.txt new file mode 100644 index 0000000..24482cc --- /dev/null +++ b/dataset/prompts/set_cover.txt @@ -0,0 +1,18 @@ +We aim to solve a Set Covering problem where a set of locations needs to be covered by selecting the minimum number of available sets. +The objective is to minimize the number of selected sets while ensuring that every location is covered by at least one selected set. +Example input data: + +1. sets.csv +set_id,cost +1,3 +2,5 +3,2 +4,4 + +2. 
coverage.csv +set_id,location_id +1,1 +1,2 +2,1 +2,3 +2,4 \ No newline at end of file diff --git a/dataset/prompts/traveling_salesman.txt b/dataset/prompts/traveling_salesman.txt index 3f21ee4..924f179 100644 --- a/dataset/prompts/traveling_salesman.txt +++ b/dataset/prompts/traveling_salesman.txt @@ -14,11 +14,3 @@ city_id,city_name from,to,distance CityA,CityB,10 CityA,CityC,8 -CityA,CityD,15 -CityA,CityE,12 -CityB,CityC,5 -CityB,CityD,9 -CityB,CityE,7 -CityC,CityD,6 -CityC,CityE,4 -CityD,CityE,3 From 5b633c09abc5152d1afce7162ee20a8ce6cc7e4a Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:53:45 +0900 Subject: [PATCH 07/12] Stabilize interface --- src/ConstraintsTranslator.jl | 11 +++++++++-- src/llm.jl | 24 +++++++++++++++++++++--- src/parsing.jl | 14 ++++++++++++++ src/template.jl | 15 +++++++++------ src/templates/ExtractStructure.json | 23 +++++++++++++++++++++++ src/templates/JumpifyModel.json | 24 ++++++++++++++++++++++++ src/templates/README.md | 5 +++++ 7 files changed, 105 insertions(+), 11 deletions(-) create mode 100644 src/templates/ExtractStructure.json create mode 100644 src/templates/JumpifyModel.json create mode 100644 src/templates/README.md diff --git a/src/ConstraintsTranslator.jl b/src/ConstraintsTranslator.jl index 27e54b6..1618a91 100644 --- a/src/ConstraintsTranslator.jl +++ b/src/ConstraintsTranslator.jl @@ -2,24 +2,31 @@ module ConstraintsTranslator # Imports import HTTP +import JSONSchema import JSON3 import TestItems: @testitem import Constraints: USUAL_CONSTRAINTS +import REPL +using REPL.TerminalMenus # Exports +export AbstractLLM export parse_code -export translate export Prompt export PromptTemplate export GroqLLM export GoogleLLM export get_completion export stream_completion +export read_template +export format_template +export translate # Includes include("prompt.jl") include("template.jl") include("llm.jl") -include("utils.jl") +include("parsing.jl") 
+include("translate.jl") end diff --git a/src/llm.jl b/src/llm.jl index 3bab7e8..8446a2e 100644 --- a/src/llm.jl +++ b/src/llm.jl @@ -86,7 +86,7 @@ end """ stream_completion(llm::GroqLLM, prompt::Prompt) Returns a completion for the given prompt using the Groq LLM API. -The completion is streamed as it is generated and printed to the terminal. +The completion is streamed to the terminal as it is generated. """ function stream_completion(llm::GroqLLM, prompt::Prompt) headers = [ @@ -136,7 +136,7 @@ end """ stream_completion(llm::GoogleLLM, prompt::Prompt) Returns a completion for the given prompt using the Google Gemini LLM API. -The completion is streamed as it is generated and printed to the terminal. +The completion is streamed to the terminal as it is generated. """ function stream_completion(llm::GoogleLLM, prompt::Prompt) url = replace(GEMINI_URL_STREAM, "{{model_id}}" => llm.model_id) @@ -167,4 +167,22 @@ function stream_completion(llm::GoogleLLM, prompt::Prompt) end end return accumulated_content -end \ No newline at end of file +end + +""" + stream_completion(llm::AbstractLLM, prompt::AbstractPrompt) +Returns a completion for a `prompt` using the `llm` model API. +The completion is streamed to the terminal as it is generated. +""" +function stream_completion(llm::AbstractLLM, prompt::AbstractPrompt) + @warn "Not implemented for this LLM and/or prompt type. Falling back to get_completion." + return get_completion(llm, prompt) +end + +""" + get_completion(llm::AbstractLLM, prompt::AbstractPrompt) +Returns a completion for a `prompt` using the `llm` model API. 
+""" +function get_completion(llm::AbstractLLM, prompt::AbstractPrompt) + error("Not implemented for this LLM and/or prompt type.") +end diff --git a/src/parsing.jl b/src/parsing.jl index ae821cc..e322b47 100644 --- a/src/parsing.jl +++ b/src/parsing.jl @@ -25,4 +25,18 @@ function parse_code(s::String) end return code_dict +end + +""" + edit_in_vim(initial_text::String) +Edits the input string `initial_text` in a temporary file using the Vim editor. +Returns the modified string after the editor is closed. +""" +function edit_in_vim(initial_text::String) + temp_filename = tempname() + write(temp_filename, initial_text) + run(`vim $temp_filename`) + edited_text = read(temp_filename, String) + rm(temp_filename) + return edited_text end \ No newline at end of file diff --git a/src/template.jl b/src/template.jl index 327f74c..6c484d5 100644 --- a/src/template.jl +++ b/src/template.jl @@ -78,7 +78,7 @@ TODO: validate the JSON data against a schema to ensure it is valid before parsi """ function read_template(data_path::String)::PromptTemplate file_content = read(data_path, String) - data = JSON.parse(file_content) + data = JSON3.read(file_content) metadata = nothing system = nothing @@ -134,28 +134,31 @@ function format_template(template::PromptTemplate; kwargs...) 
system = template.system.content user = template.user.content - template_vars = union(Set(template.system.variables), Set(template.user.variables)) + template_vars = union( + Set(Symbol.(template.system.variables)), + Set(Symbol.(template.user.variables)), + ) kwargs_keys = Set(keys(kwargs)) # Check for missing variables in kwargs - missing_vars = template_vars - kwargs_keys + missing_vars = setdiff(template_vars, kwargs_keys) if !isempty(missing_vars) error("Missing variables in kwargs: $(collect(missing_vars))") end # Check for extra kwargs - extra_kwargs = kwargs_keys - template_vars + extra_kwargs = setdiff(kwargs_keys, template_vars) if !isempty(extra_kwargs) @warn "Extra kwargs will be ignored: $(collect(extra_kwargs))" end # Substitute variables in the system and user content for var in template.system.variables - system = replace(system, "{{$(var)}}", kwargs[var]) + system = replace(system, "{{$(var)}}" => kwargs[Symbol(var)]) end for var in template.user.variables - user = replace(user, "{{$(var)}}", kwargs[var]) + user = replace(user, "{{$(var)}}" => kwargs[Symbol(var)]) end return Prompt(system, user) diff --git a/src/templates/ExtractStructure.json b/src/templates/ExtractStructure.json new file mode 100644 index 0000000..1993195 --- /dev/null +++ b/src/templates/ExtractStructure.json @@ -0,0 +1,23 @@ +[ + { + "content": "Template Metadata", + "description": "Instructs the LLM to extract the high-level structure of the optimization problem based on the given description.", + "version": "2.0", + "source": "", + "_type": "metadatamessage" + }, + { + "content": "You are an AI assistant specialized in modeling Constraint Programming problems. You have extensive knowledge of the constraints commonly used in Constraint Programming, especially XCSP3 constraints.\nYour task is to examine a given problem description and extract key structural information. You must focus on the general form of the problem rather than specific instances or numerical values. 
Provide your analysis in the following format:\n\n1. Problem Description:\n- Summarize the problem statement and all of its specifications.\n\n2. Parameter Sets:\n- Identify sets of known quantities (i.e., data) given in the problem description. These are fixed inputs to the problem, not determined by the optimization process.\n- For each set of parameters:\n* Provide a descriptive name for the set.\n\n*Define a notation for the set using subscripts (e.g., a_ijk), specifying the meaning and the range of each index.\n\n3. Decision Variables:\n- Identify the key sets of decisions that need to be made. For each set of decision variables:\n* Provide a descriptive name for the set.\n* Specify the domain (possible values) for elements in this set, which can be either binary, integer or continuous.\n*Define a notation for the set using subscripts (e.g., x_ijk), specifying the meaning and the range of each index.\n\n4. Problem Type: determine whether the problem is a satisfaction or an optimization problem. If it is an optimization problem, provide a Description of the objective function and a symbolic Expression, consistently with the notation already defined. Otherwise, just concisely state that the problem is a satisfaction problem.\n\n5. Constraints. Express the problem's constraint using user-provided Core Constraints. For each constraint:\n* Provide a short textual description\n*Provide the Core Constraint enforcing the constraint. List of core constraints:\n{{constraints}}\n\nIMPORTANT: think step-by-step: a good problem formulation should be clear and concise, with the fewest possible variables and constraints. You must not refer to constraints outside the Core Constraints list. 
You must output the requested information only.", + "variables": [ + "constraints" + ], + "_type": "systemmessage" + }, + { + "content": "# Problem description: {{description}}", + "variables": [ + "description" + ], + "_type": "usermessage" + } +] \ No newline at end of file diff --git a/src/templates/JumpifyModel.json b/src/templates/JumpifyModel.json new file mode 100644 index 0000000..4cf4c2b --- /dev/null +++ b/src/templates/JumpifyModel.json @@ -0,0 +1,24 @@ +[ + { + "content": "Template Metadata", + "description": "Instructs the LLM to convert a structured textual description of a Constraint Programming problem into a JuMP model to be solved with LocalSearchSolvers.jl.", + "version": "2.0", + "source": "", + "_type": "metadatamessage" + }, + { + "content": "You are an AI assistant specialized in modeling Constraint Programming problems. Your task is to examine a given description of a Constraint Programming model and provide a code implementation in Julia, using JuMP and the CBLS solver.\nConstraints MUST be expressed with the following JuMP syntax: `@constraint(model, x in ConstraintName(kwargs)`, where `x` is a vector of variables, `ConstraintName` is the name of the constraint in camel-case (example: all different constraint -> AllDifferent()), and `kwargs` are the keyword arguments for the constraint (example: Sum(op=<=, val=10).\nIMPORTANT: 1. Output only code with no additional text.\n2. You must write a docstring for the code.\n3. The code must be succinct and capture all the described constraints.\n4. You MUST use the provide syntax to express constraints. Do NOT express constraints in algebraic form. 
\n Example output for the Magic Square Problem:\n{{example_magic_square}}\n\nExample output for the Quadratic Assignment problem:\n{{example_qap}}.", + "variables": [ + "example_magic_square", + "example_qap" + ], + "_type": "systemmessage" + }, + { + "content": "{{description}}", + "variables": [ + "description" + ], + "_type": "usermessage" + } +] \ No newline at end of file diff --git a/src/templates/README.md b/src/templates/README.md new file mode 100644 index 0000000..b0d26d5 --- /dev/null +++ b/src/templates/README.md @@ -0,0 +1,5 @@ +# Prompt templates + +This folder contains the prompt templates used in the different components of this package. + +The prompts are `.json` files in the template format defined by [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl). \ No newline at end of file From c7886c50373d0b779add3df0c222b66785da3ed2 Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Tue, 10 Sep 2024 17:58:00 +0900 Subject: [PATCH 08/12] Semi-functional interface --- Manifest.toml | 58 ++++-------- Project.toml | 11 ++- README.md | 47 +++++++++- dataset/prompts/quadratic_assignment.txt | 13 +-- dataset/prompts/traveling_salesman.txt | 3 - examples/README.md | 3 + examples/magic_square.md | 28 ++++++ examples/quadratic_assignment.md | 25 ++++++ src/ConstraintsTranslator.jl | 3 +- src/llm.jl | 100 ++++++++++++++++++++- src/template.jl | 17 ++-- src/templates/ExtractStructure.json | 23 ----- src/templates/JumpifyModel.json | 24 ----- src/translate.jl | 110 +++++++++++++++++++++++ templates/ExtractStructure.json | 23 +++++ templates/JumpifyModel.json | 23 +++++ {src/templates => templates}/README.md | 0 templates/TemplateSchema.json | 96 ++++++++++++++++++++ test/{utils.jl => parsing.jl} | 0 test/runtests.jl | 2 +- 20 files changed, 497 insertions(+), 112 deletions(-) create mode 100644 examples/README.md create mode 100644 examples/magic_square.md create mode 100644 
examples/quadratic_assignment.md delete mode 100644 src/templates/ExtractStructure.json delete mode 100644 src/templates/JumpifyModel.json create mode 100644 src/translate.jl create mode 100644 templates/ExtractStructure.json create mode 100644 templates/JumpifyModel.json rename {src/templates => templates}/README.md (100%) create mode 100644 templates/TemplateSchema.json rename test/{utils.jl => parsing.jl} (100%) diff --git a/Manifest.toml b/Manifest.toml index 77883a6..4af0f2b 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "2ab469f27c9700c80fa8dc8e1198e45829d4137c" +project_hash = "e3750de370b9d1d33722c2ed1083743cce2dcb1e" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -25,12 +25,6 @@ git-tree-sha1 = "0157e592151e39fa570645e2b2debcdfb8a0f112" uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" version = "3.4.3" -[[deps.CodeTracking]] -deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "7eee164f122511d3e4e1ebadb7956939ea7e1c77" -uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.3.6" - [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" @@ -84,9 +78,9 @@ version = "0.3.13" [[deps.Constraints]] deps = ["CompositionalNetworks", "ConstraintCommons", "ConstraintDomains", "DataFrames", "Dictionaries", "MacroTools", "PrettyTables", "TestItems"] -path = "../Constraints" +git-tree-sha1 = "8256d3a55ad8e7be10fa4e18325ad39dfbd24c68" uuid = "30f324ab-b02d-43f0-b619-e131c61659f7" -version = "0.5.6" +version = "0.5.7" [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" @@ -139,10 +133,6 @@ version = "0.10.11" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - [[deps.DocStringExtensions]] deps = ["LibGit2"] 
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" @@ -250,18 +240,18 @@ version = "1.14.0" [deps.JSON3.weakdeps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" +[[deps.JSONSchema]] +deps = ["Downloads", "JSON", "JSON3", "URIs"] +git-tree-sha1 = "243f1cdb476835d7c249deb9f29ad6b7827da7d3" +uuid = "7d188eb4-7ad8-530c-ae41-71a32a6d4692" +version = "1.4.1" + [[deps.JuliaFormatter]] deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "PrecompileTools", "TOML", "Tokenize"] git-tree-sha1 = "bb4696471330275adfd6c78c6173f276e8c067aa" uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" version = "1.0.60" -[[deps.JuliaInterpreter]] -deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] -git-tree-sha1 = "4b415b6cccb9ab61fec78a621572c82ac7fa5776" -uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" -version = "0.9.35" - [[deps.LaTeXStrings]] git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" @@ -329,12 +319,6 @@ git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.0.3" -[[deps.LoweredCodeUtils]] -deps = ["JuliaInterpreter"] -git-tree-sha1 = "1ce1834f9644a8f7c011eb0592b7fd6c42c90653" -uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" -version = "3.0.1" - [[deps.MacroTools]] deps = ["Markdown", "Random"] git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" @@ -392,9 +376,9 @@ version = "1.4.3" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +git-tree-sha1 = "1b35263570443fdd9e76c76b7062116e2f374ab8" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" +version = "3.0.15+0" [[deps.OrderedCollections]] git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" @@ -465,18 +449,6 @@ git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" version = "1.2.2" -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 
= "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.Revise]] -deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "REPL", "Requires", "UUIDs", "Unicode"] -git-tree-sha1 = "7b7850bb94f75762d567834d7e9802fc22d62f9c" -uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" -version = "3.5.18" - [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" @@ -556,9 +528,9 @@ version = "1.0.3" [[deps.TZJData]] deps = ["Artifacts"] -git-tree-sha1 = "1607ad46cf8d642aa779a1d45af1c8620dbf6915" +git-tree-sha1 = "36b40607bf2bf856828690e097e1c799623b0602" uuid = "dc5dba14-91b3-4cab-a142-028a31da12f7" -version = "1.2.0+2024a" +version = "1.3.0+2024b" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -594,9 +566,9 @@ version = "1.0.0" [[deps.TimeZones]] deps = ["Dates", "Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"] -git-tree-sha1 = "b92aebdd3555f3a7e3267cf17702033c2814ef48" +git-tree-sha1 = "8323074bc977aa85cf5ad71099a83ac75b0ac107" uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" -version = "1.18.0" +version = "1.18.1" weakdeps = ["RecipesBase"] [deps.TimeZones.extensions] diff --git a/Project.toml b/Project.toml index 251c63e..8e1a3ed 100644 --- a/Project.toml +++ b/Project.toml @@ -7,12 +7,17 @@ version = "0.0.1" Constraints = "30f324ab-b02d-43f0-b619-e131c61659f7" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" +JSONSchema = "7d188eb4-7ad8-530c-ae41-71a32a6d4692" +REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" [compat] -Aqua = "0" -JET = "0" +Aqua = "0.8" +Constraints = "0.5" +HTTP = "1.10" +JET = "0.9" +JSON3 = "1" +JSONSchema = "1" Test = "1" TestItemRunner = "1" TestItems = "1" diff --git a/README.md b/README.md index 
1d8faea..83d290b 100644 --- a/README.md +++ b/README.md @@ -4,5 +4,50 @@ [![Coverage](https://codecov.io/gh/Azzaare/ConstraintsTranslator.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/Azzaare/ConstraintsTranslator.jl) [![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) -A package for translating natural language into Constraint Programming models for `LocalSearchSolvers.jl`. +A package for translating natural-language descriptions of optimization problems into Constraint Programming models to be solved via [`CBLS.jl`](https://github.com/JuliaConstraints/CBLS.jl) using Large Language Models (LLMs). + +This package acts as a light wrapper around common LLM API endpoints, supplying appropriate system prompts and context information to the LLMs to generate CP models. Specifically, we first prompt the model to generate a high-level representation of the problem in editable Markdown format, and then we prompt the model to generate Julia code. + +We currently support the following LLM APIs: +- Groq (https://groq.com) +- Google Gemini (https://ai.google.dev) +- llama.cpp (https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md) + +## Why not OpenAI / Anthropic / etc.? +Groq and Gemini are currently offering rate-limited free access to their APIs, and llama.cpp is free and open-source. We are still actively experimenting with this package, and we are not in a position to pay for API access. We might consider adding support for other APIs in the future. + +## Workflow example +To begin playing with the package, you can start from the example below: + +```julia +using ConstraintsTranslator + +llm = GoogleLLM("gemini-1.5-pro") + +description = """ +We need to determine the shortest possible route for a salesman who must visit a set of cities exactly once and return to the starting city. 
+The objective is to minimize the total travel distance while ensuring that each city is visited exactly once. + +Example input data: +1. cities.csv +city_id,city_name +1,CityA +2,CityB + +2. distances.csv +from,to,distance +CityA,CityB,10 +CityA,CityC,8 +""" + +response = translate(llm, description) +``` + +The `translate` function will first produce a Markdown representation of the problem, and then return the generated Julia code for parsing the input data and building the model. + +This example uses Google Gemini as an LLM. You will need an API key and a model id to access proprietary API endpoints. Use `help?>` in the Julia REPL to learn more about the available models. + +At each generation step, it will prompt the user in an interactive menu to accept the answer, edit the prompt and/or the generated text, or generate another answer with the same prompt. + +The LLM expects the user to provide examples of the input data format. If no examples are present, the LLM will make assumptions about the data format based on the problem description. diff --git a/dataset/prompts/quadratic_assignment.txt b/dataset/prompts/quadratic_assignment.txt index 6b5db20..c76f0f1 100644 --- a/dataset/prompts/quadratic_assignment.txt +++ b/dataset/prompts/quadratic_assignment.txt @@ -1,20 +1,22 @@ We aim to assign a set of facilities to a set of locations. -The objective is to minimize the total cost, which is equal to the distance between locations time the flow between facilities. +The objective is to minimize the total cost, which is equal to the distance between locations times the flow between facilities. Example input data: -facilities.csv +1. facilities.csv facility_id 1 2 3 4 -locations.csv + +2. locations.csv location_id 1 2 3 4 -flow.csv + +3. flow.csv facility_id_1,facility_id_2,flow 1,2,10 1,3,8 @@ -22,7 +24,8 @@ facility_id_1,facility_id_2,flow 2,3,6 2,4,9 3,4,7 -distance.csv + +4. 
distance.csv location_id_1,location_id_2,distance 1,2,4 1,3,7 diff --git a/dataset/prompts/traveling_salesman.txt b/dataset/prompts/traveling_salesman.txt index 924f179..ce4f06f 100644 --- a/dataset/prompts/traveling_salesman.txt +++ b/dataset/prompts/traveling_salesman.txt @@ -6,9 +6,6 @@ Example input data: city_id,city_name 1,CityA 2,CityB -3,CityC -4,CityD -5,CityE 2. distances.csv from,to,distance diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..0709dc5 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,3 @@ +# In-context examples + +This folder contains `CBLS.jl` implementations to be used as in-context examples for the LLMs used in `ConstraintsTranslator.jl` \ No newline at end of file diff --git a/examples/magic_square.md b/examples/magic_square.md new file mode 100644 index 0000000..20e1fdf --- /dev/null +++ b/examples/magic_square.md @@ -0,0 +1,28 @@ +# Example for a Magic Square problem +```julia +""" + magic_square(n) + +Create a JuMP model for an n x n magic square. + +## Arguments +- `n::Int`: size of the magic square. 
+""" +function magic_square(n) + N = n^2 + model = JuMP.Model(CBLS.Optimizer) + magic_constant = n * (N + 1) / 2 + + @variable(model, 1 ≤ X[1:n, 1:n] ≤ N, Int) + @constraint(model, vec(X) in AllDifferent()) + + for i in 1:n + @constraint(model, X[i, :] in AllEqual(; val=magic_constant)) + @constraint(model, X[:, i] in AllEqual(; val=magic_constant)) + end + @constraint(model, [X[i, i] for i in 1:n] in AllEqual(; val=magic_constant)) + @constraint(model, [X[i, n+1-i] for i in 1:n] in AllEqual(; val=magic_constant)) + + return model, X +end +``` \ No newline at end of file diff --git a/examples/quadratic_assignment.md b/examples/quadratic_assignment.md new file mode 100644 index 0000000..3ebce51 --- /dev/null +++ b/examples/quadratic_assignment.md @@ -0,0 +1,25 @@ +# Example for the Quadratic Assignment Problem +```julia +""" +qap(n, W, D) + +Create a JuMP model for the Quadratic Assignment Problem (QAP). + +## Arguments +- `n::Int`: problem size. +- `W::Matrix`: flow matrix. +- `D::Matrix`: distance matrix. 
+""" +function qap(n, W, D) + model = JuMP.Model(CBLS.Optimizer) + + @variable(model, 1 ≤ X[1:n] ≤ n, Int) + @constraint(model, X in AllDifferent()) + + Σwd = p -> sum(sum(W[p[i], p[j]] * D[i, j] for j in 1:n) for i in 1:n) + + @objective(model, Min, ScalarFunction(Σwd)) + + return model, X +end +``` \ No newline at end of file diff --git a/src/ConstraintsTranslator.jl b/src/ConstraintsTranslator.jl index 1618a91..2defd32 100644 --- a/src/ConstraintsTranslator.jl +++ b/src/ConstraintsTranslator.jl @@ -16,6 +16,7 @@ export Prompt export PromptTemplate export GroqLLM export GoogleLLM +export LlamaCppLLM export get_completion export stream_completion export read_template @@ -29,4 +30,4 @@ include("llm.jl") include("parsing.jl") include("translate.jl") -end +end \ No newline at end of file diff --git a/src/llm.jl b/src/llm.jl index 8446a2e..2ce08da 100644 --- a/src/llm.jl +++ b/src/llm.jl @@ -42,6 +42,24 @@ struct GoogleLLM <: AbstractLLM end end +""" + LlamaCppLLM +Structure encapsulating the parameters for accessing the llama.cpp server API. +- `api_key`: an optional API key for accessing the server +- `url`: the URL of the llama.cpp server OpenAI API endpoint (e.g., http://localhost:8080) +NOTE: we do not apply the appropriate chat templates to the prompt. +This must be handled either in an external code path or by the server. +""" +struct LlamaCppLLM <: AbstractLLM + api_key::String + url::String + + function LlamaCppLLM(url::String) + api_key = get(ENV, "LLAMA_CPP_API_KEY", "no-key") + new(api_key, url) + end +end + """ get_completion(llm::GroqLLM, prompt::Prompt) Returns a completion for the given prompt using the Groq LLM API. 
@@ -75,7 +93,7 @@ function get_completion(llm::GoogleLLM, prompt::Prompt) ] body = JSON3.write(Dict( "contents" => Dict( - "parts" => Dict("text" => prompt.system * prompt.user) + "parts" => Dict("text" => join([prompt.system, prompt.user], "\n")) ), )) response = HTTP.post(url, headers, body) @@ -83,6 +101,80 @@ function get_completion(llm::GoogleLLM, prompt::Prompt) return body["candidates"][1]["content"]["parts"][1]["text"] end +""" + get_completion(llm::LlamaCppLLM, prompt::Prompt) +Returns a completion for the given prompt using the llama.cpp server API. +""" +function get_completion(llm::LlamaCppLLM, prompt::Prompt) + url = join([llm.url, "v1/chat/completions"], "/") + header = [ + "Authorization" => "Bearer $(llm.api_key)", + "Content-Type" => "application/json", + ] + body = JSON3.write(Dict( + "messages" => [ + Dict("role" => "system", "content" => prompt.system), + Dict("role" => "user", "content" => prompt.user), + ], + )) + response = HTTP.post(url, header, body) + body = JSON3.read(response.body) + return body["choices"][1]["message"]["content"] +end + +""" + stream_completion(llm::LlamaCppLLM, prompt::Prompt) +Returns a completion for the given prompt using the llama.cpp server API. +The completion is streamed to the terminal as it is generated. 
+""" +function stream_completion(llm::LlamaCppLLM, prompt::Prompt) + url = join([llm.url, "v1/chat/completions"], "/") + headers = [ + "Authorization" => "Bearer $(llm.api_key)", + "Content-Type" => "application/json", + ] + body = JSON3.write(Dict( + "messages" => [ + Dict("role" => "system", "content" => prompt.system), + Dict("role" => "user", "content" => prompt.user), + ], + "stream" => true, + )) + + accumulated_content = "" + event_buffer = "" + + HTTP.open(:POST, url, headers; body = body) do io + write(io, body) + HTTP.closewrite(io) + HTTP.startread(io) + while !eof(io) + chunk = String(readavailable(io)) + events = split(chunk, "\n\n") + if !endswith(event_buffer, "\n\n") + event_buffer = events[end] + events = events[1:(end - 1)] + else + event_buffer = "" + end + events = join(events, "\n") + for line in eachmatch(r"(?<=data: ).*", events, overlap = true) + if line.match == "[DONE]" + print("\n") + break + end + message = JSON3.read(line.match) + if !isempty(message["choices"][1]["delta"]) + print(message["choices"][1]["delta"]["content"]) + accumulated_content *= message["choices"][1]["delta"]["content"] + end + end + end + HTTP.closeread(io) + end + return accumulated_content +end + """ stream_completion(llm::GroqLLM, prompt::Prompt) Returns a completion for the given prompt using the Groq LLM API. 
@@ -107,6 +199,8 @@ function stream_completion(llm::GroqLLM, prompt::Prompt) HTTP.open(:POST, GROQ_URL, headers; body = body) do io write(io, body) + HTTP.closewrite(io) + HTTP.startread(io) while !eof(io) chunk = String(readavailable(io)) events = split(chunk, "\n\n") @@ -129,6 +223,7 @@ function stream_completion(llm::GroqLLM, prompt::Prompt) end end end + HTTP.closeread(io) end return accumulated_content end @@ -154,6 +249,8 @@ function stream_completion(llm::GoogleLLM, prompt::Prompt) HTTP.open(:POST, url, headers; body = body) do io write(io, body) + HTTP.closewrite(io) + HTTP.startread(io) while !eof(io) chunk = String(readavailable(io)) line = match(r"(?<=data: ).*", chunk) @@ -165,6 +262,7 @@ function stream_completion(llm::GoogleLLM, prompt::Prompt) print(message["candidates"][1]["content"]["parts"][1]["text"]) accumulated_content *= String(message["candidates"][1]["content"]["parts"][1]["text"]) end + HTTP.closeread(io) end return accumulated_content end diff --git a/src/template.jl b/src/template.jl index 6c484d5..caa99b1 100644 --- a/src/template.jl +++ b/src/template.jl @@ -76,10 +76,19 @@ TODO: validate the JSON data against a schema to ensure it is valid before parsi # Raises - `ErrorException`: if the template does not match the specification format. 
""" -function read_template(data_path::String)::PromptTemplate +function read_template(data_path::String) file_content = read(data_path, String) data = JSON3.read(file_content) + package_path = pkgdir(@__MODULE__) + schema_path = joinpath(package_path, "templates", "TemplateSchema.json") + schema_content = read(schema_path, String) + schema = JSONSchema.Schema(JSON3.read(schema_content)) + + if !isnothing(JSONSchema.validate(schema, data)) + error("Invalid template format.") + end + metadata = nothing system = nothing user = nothing @@ -102,15 +111,9 @@ function read_template(data_path::String)::PromptTemplate item["content"], item["variables"], ) - else - error("Unknown message type: $_type") end end - if isnothing(metadata) || isnothing(system) || isnothing(user) - error("Template must contain metadata, system, and user messages") - end - return PromptTemplate(metadata, system, user) end diff --git a/src/templates/ExtractStructure.json b/src/templates/ExtractStructure.json deleted file mode 100644 index 1993195..0000000 --- a/src/templates/ExtractStructure.json +++ /dev/null @@ -1,23 +0,0 @@ -[ - { - "content": "Template Metadata", - "description": "Instructs the LLM to extract the high-level structure of the optimization problem based on the given description.", - "version": "2.0", - "source": "", - "_type": "metadatamessage" - }, - { - "content": "You are an AI assistant specialized in modeling Constraint Programming problems. You have extensive knowledge of the constraints commonly used in Constraint Programming, especially XCSP3 constraints.\nYour task is to examine a given problem description and extract key structural information. You must focus on the general form of the problem rather than specific instances or numerical values. Provide your analysis in the following format:\n\n1. Problem Description:\n- Summarize the problem statement and all of its specifications.\n\n2. 
Parameter Sets:\n- Identify sets of known quantities (i.e., data) given in the problem description. These are fixed inputs to the problem, not determined by the optimization process.\n- For each set of parameters:\n* Provide a descriptive name for the set.\n\n*Define a notation for the set using subscripts (e.g., a_ijk), specifying the meaning and the range of each index.\n\n3. Decision Variables:\n- Identify the key sets of decisions that need to be made. For each set of decision variables:\n* Provide a descriptive name for the set.\n* Specify the domain (possible values) for elements in this set, which can be either binary, integer or continuous.\n*Define a notation for the set using subscripts (e.g., x_ijk), specifying the meaning and the range of each index.\n\n4. Problem Type: determine whether the problem is a satisfaction or an optimization problem. If it is an optimization problem, provide a Description of the objective function and a symbolic Expression, consistently with the notation already defined. Otherwise, just concisely state that the problem is a satisfaction problem.\n\n5. Constraints. Express the problem's constraint using user-provided Core Constraints. For each constraint:\n* Provide a short textual description\n*Provide the Core Constraint enforcing the constraint. List of core constraints:\n{{constraints}}\n\nIMPORTANT: think step-by-step: a good problem formulation should be clear and concise, with the fewest possible variables and constraints. You must not refer to constraints outside the Core Constraints list. 
You must output the requested information only.", - "variables": [ - "constraints" - ], - "_type": "systemmessage" - }, - { - "content": "# Problem description: {{description}}", - "variables": [ - "description" - ], - "_type": "usermessage" - } -] \ No newline at end of file diff --git a/src/templates/JumpifyModel.json b/src/templates/JumpifyModel.json deleted file mode 100644 index 4cf4c2b..0000000 --- a/src/templates/JumpifyModel.json +++ /dev/null @@ -1,24 +0,0 @@ -[ - { - "content": "Template Metadata", - "description": "Instructs the LLM to convert a structured textual description of a Constraint Programming problem into a JuMP model to be solved with LocalSearchSolvers.jl.", - "version": "2.0", - "source": "", - "_type": "metadatamessage" - }, - { - "content": "You are an AI assistant specialized in modeling Constraint Programming problems. Your task is to examine a given description of a Constraint Programming model and provide a code implementation in Julia, using JuMP and the CBLS solver.\nConstraints MUST be expressed with the following JuMP syntax: `@constraint(model, x in ConstraintName(kwargs)`, where `x` is a vector of variables, `ConstraintName` is the name of the constraint in camel-case (example: all different constraint -> AllDifferent()), and `kwargs` are the keyword arguments for the constraint (example: Sum(op=<=, val=10).\nIMPORTANT: 1. Output only code with no additional text.\n2. You must write a docstring for the code.\n3. The code must be succinct and capture all the described constraints.\n4. You MUST use the provide syntax to express constraints. Do NOT express constraints in algebraic form. 
\n Example output for the Magic Square Problem:\n{{example_magic_square}}\n\nExample output for the Quadratic Assignment problem:\n{{example_qap}}.",
-        "variables": [
-            "example_magic_square",
-            "example_qap"
-        ],
-        "_type": "systemmessage"
-    },
-    {
-        "content": "{{description}}",
-        "variables": [
-            "description"
-        ],
-        "_type": "usermessage"
-    }
-]
\ No newline at end of file
diff --git a/src/translate.jl b/src/translate.jl
new file mode 100644
index 0000000..16cb738
--- /dev/null
+++ b/src/translate.jl
@@ -0,0 +1,136 @@
+"""
+    extract_structure(model::AbstractLLM, description::AbstractString, constraints::AbstractString)
+Extracts the parameters, decision variables and constraints of an optimization problem given a natural-language `description`.
+The `constraints` text is substituted into the prompt template together with `description`.
+Returns a plaintext Markdown-formatted document containing the above information.
+
+The response of `model` is streamed to the terminal; the user can then accept it, edit it,
+or query `model` again with an edited description or with the same prompt.
+
+# Raises
+- `InterruptException`: if the menu is cancelled.
+"""
+function extract_structure(
+    model::AbstractLLM,
+    description::AbstractString,
+    constraints::AbstractString,
+)
+    package_path::String = pkgdir(@__MODULE__)
+    template_path = joinpath(package_path, "templates", "ExtractStructure.json")
+    template = read_template(template_path)
+    prompt = format_template(template; description, constraints)
+    response = stream_completion(model, prompt)
+
+    options = [
+        "Accept the response",
+        "Edit the response",
+        "Try again with a different prompt",
+        "Try again with the same prompt",
+    ]
+    menu = RadioMenu(options; pagesize = 4)
+
+    while true
+        choice = request("What do you want to do?", menu)
+        if choice == 1
+            break
+        elseif choice == 2
+            response = edit_in_vim(response)
+            println(response)
+        elseif choice == 3
+            description = edit_in_vim(description)
+            prompt = format_template(template; description, constraints)
+            response = stream_completion(model, prompt)
+        elseif choice == 4
+            response = stream_completion(model, prompt)
+        elseif choice == -1
+            # The exception has to be thrown: merely constructing it is a no-op and
+            # the menu would be shown again forever.
+            throw(InterruptException())
+        end
+    end
+    return response
+end
+
+"""
+    jumpify_model(model::AbstractLLM, description::AbstractString, examples::AbstractString)
+Queries `model` with the `JumpifyModel.json` template, in which `description` and `examples` are substituted.
+Returns the response of `model`, possibly edited by the user.
+
+The response of `model` is streamed to the terminal; the user can then accept it, edit it,
+or query `model` again with an edited description or with the same prompt.
+
+# Raises
+- `InterruptException`: if the menu is cancelled.
+"""
+function jumpify_model(
+    model::AbstractLLM,
+    description::AbstractString,
+    examples::AbstractString,
+)
+    package_path::String = pkgdir(@__MODULE__)
+    template_path = joinpath(package_path, "templates", "JumpifyModel.json")
+    template = read_template(template_path)
+    prompt = format_template(template; description, examples)
+    response = stream_completion(model, prompt)
+
+    options = [
+        "Accept the response",
+        "Edit the response",
+        "Try again with a different prompt",
+        "Try again with the same prompt",
+    ]
+    menu = RadioMenu(options; pagesize = 4)
+    while true
+        choice = request("What do you want to do?", menu)
+        if choice == 1
+            break
+        elseif choice == 2
+            response = edit_in_vim(response)
+            println(response)
+        elseif choice == 3
+            description = edit_in_vim(description)
+            prompt = format_template(template; description, examples)
+            response = stream_completion(model, prompt)
+        elseif choice == 4
+            response = stream_completion(model, prompt)
+        elseif choice == -1
+            # The exception has to be thrown: merely constructing it is a no-op and
+            # the menu would be shown again forever.
+            throw(InterruptException())
+        end
+    end
+    return response
+end
+
+"""
+    translate(model::AbstractLLM, description::AbstractString)
+Translate the natural-language `description` of an optimization problem into a Constraint Programming model
+by querying the Large Language Model `model`.
+
+The structure of the problem is first extracted with `extract_structure`, then turned into
+code with `jumpify_model`, using every `.md` file of the package `examples` directory as
+examples. Returns the `"julia"` entry of the parsed response.
+"""
+function translate(model::AbstractLLM, description::AbstractString)
+    constraints = String[]
+    for (name, cons) in USUAL_CONSTRAINTS
+        push!(constraints, "$(name): $(lstrip(cons.description))")
+    end
+    constraints = join(constraints, "\n")
+
+    structure = extract_structure(model, description, constraints)
+
+    package_path::String = pkgdir(@__MODULE__)
+    examples_path = joinpath(package_path, "examples")
+    examples_files = filter(x -> endswith(x, ".md"), readdir(examples_path))
+    examples = String[]
+    for file in examples_files
+        example = read(joinpath(examples_path, file), String)
+        push!(examples, example)
+    end
+    examples = join(examples, "\n")
+
+    response = jumpify_model(model, structure, examples)
+
+    return parse_code(response)["julia"]
+end
diff --git a/templates/ExtractStructure.json b/templates/ExtractStructure.json
new file mode 100644
index 0000000..8a2feeb
--- /dev/null
+++ b/templates/ExtractStructure.json
@@ -0,0 +1,23 @@
+[
+    {
+        "content": "Template Metadata",
+        "description": "Instructs the LLM to extract the high-level structure of the optimization problem based on the given description.",
+        "version": "2.0",
+        "source": "",
+        "_type": "metadatamessage"
+    },
+    {
+        "content": "You are an AI assistant specialized in modeling Constraint Programming (CP) problems. You have extensive knowledge of the XCSP3 Constraints and of the most used modeling patterns in Constraint Programming.\nYour task is to examine a given problem description and extract key structural information. Provide your analysis in the following format:\n\n1. Problem Description:\n- Summarize the problem statement and all of its specifications.\n\n2. Input data. Describe the format of the input data of the optimization problem. If no format is specified by the user, make sensible assumptions about one or multiple .csv files representing the problem inputs, and very concisely describe their headers.\n3. 
Parameter Sets:\n- Identify sets of known quantities given in the problem description. These are fixed inputs to the problem, not determined by the optimization process.\n- For each set of parameters:\n* Provide a descriptive name for the set.\n* Define a symbolic notation for the set using subscripts (e.g., a_ijk), specifying the meaning and the range of each index.\n\n4. Decision Variables:\n- Identify the key sets of decisions that need to be made. For each set of decision variables:\n* Provide a descriptive name for the set.\n* Specify the domain (possible values) for elements in this set, which can be either binary, integer or continuous.\n* Define a notation for the set using subscripts (e.g., x_ijk), specifying the meaning and the range of each index.\n\n5. Problem Type: determine whether the problem is a satisfaction or an optimization problem. If it is an optimization problem, provide: - a description of the objective function; - a symbolic expression, consistently with the notation already defined. Otherwise, if the problem is a satisfaction problem, concisely state this fact.\n\n6. Constraints. Express the problem's constraints using user-provided Core Constraints. You must prefer using CP-oriented global constraints when possible. 
For each constraint:\n* Write a short description\n*Write the name (only the name) of Core Constraint(s) enforcing the constraint.\n*Write the scope of the constraint, that is, the indexes of the variables appearing in the constraint.\n\nList of core constraints:\n{{constraints}}\n\nIMPORTANT: - Prioritize Constraint Programming formulations over MIP formulations.\n-You must use as few variables and constraints as possible: you must avoid useless or redundant constraints.\n-You must not refer to constraints outside the Core Constraints list.\n-You must make sure that the Core Constraints are used with the appropriate arguments.\n-You must output the requested information only.", + "variables": [ + "constraints" + ], + "_type": "systemmessage" + }, + { + "content": "# Problem description: {{description}}", + "variables": [ + "description" + ], + "_type": "usermessage" + } +] \ No newline at end of file diff --git a/templates/JumpifyModel.json b/templates/JumpifyModel.json new file mode 100644 index 0000000..de63bda --- /dev/null +++ b/templates/JumpifyModel.json @@ -0,0 +1,23 @@ +[ + { + "content": "Template Metadata", + "description": "Instructs the LLM to convert a structured textual description of a Constraint Programming problem into a JuMP model to be solved with LocalSearchSolvers.jl.", + "version": "2.0", + "source": "", + "_type": "metadatamessage" + }, + { + "content": "You are an AI assistant specialized in modeling Constraint Programming problems. Your task is to examine a given description of a Constraint Programming model and provide a code implementation in Julia, using JuMP and the CBLS solver. 
The code MUST: 1) Read the input data from external files into data structures according to the specifications provided in the description, using the appropriate Julia packages (e.g., DataFrames.jl, CSV.jl, etc.), 2) build the model, and 3) return the model.\nConstraints MUST be expressed with the following JuMP syntax: `@constraint(model, x in ConstraintName(kwargs))`, where `x` is a vector of variables, `ConstraintName` is the name of the constraint in camel-case (example: all different constraint -> AllDifferent()), and `kwargs` are the keyword arguments for the constraint (example: Sum(op=<=, val=10)).\nIMPORTANT: 1. Output only the required function with no additional text or usage examples.\n2. You must write a docstring for the code.\n3. The code must be succinct and capture all the described constraints.\n4. You MUST use the provided syntax to express constraints. Do NOT express constraints in algebraic form.\n\n{{examples}}",
+        "variables": [
+            "examples"
+        ],
+        "_type": "systemmessage"
+    },
+    {
+        "content": "{{description}}",
+        "variables": [
+            "description"
+        ],
+        "_type": "usermessage"
+    }
+]
\ No newline at end of file
diff --git a/src/templates/README.md b/templates/README.md
similarity index 100%
rename from src/templates/README.md
rename to templates/README.md
diff --git a/templates/TemplateSchema.json b/templates/TemplateSchema.json
new file mode 100644
index 0000000..1ed6fa2
--- /dev/null
+++ b/templates/TemplateSchema.json
@@ -0,0 +1,88 @@
+{
+    "$schema": "http://json-schema.org/draft-07/schema#",
+    "type": "array",
+    "minItems": 3,
+    "maxItems": 3,
+    "items": [
+        {
+            "type": "object",
+            "properties": {
+                "content": {
+                    "type": "string"
+                },
+                "description": {
+                    "type": "string"
+                },
+                "version": {
+                    "type": "string"
+                },
+                "source": {
+                    "type": "string"
+                },
+                "_type": {
+                    "type": "string",
+                    "enum": [
+                        "metadatamessage"
+                    ]
+                }
+            },
+            "required": [
+                "content",
+                "description",
+                "version",
+                "source",
+                "_type"
+            ]
+        },
+        {
+            "type": "object",
+            "properties": {
+                "content": {
+                    "type": "string"
+                },
+                "variables": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    }
+                },
+                "_type": {
+                    "type": "string",
+                    "enum": [
+                        "systemmessage"
+                    ]
+                }
+            },
+            "required": [
+                "content",
+                "variables",
+                "_type"
+            ]
+        },
+        {
+            "type": "object",
+            "properties": {
+                "content": {
+                    "type": "string"
+                },
+                "variables": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    }
+                },
+                "_type": {
+                    "type": "string",
+                    "enum": [
+                        "usermessage"
+                    ]
+                }
+            },
+            "required": [
+                "content",
+                "variables",
+                "_type"
+            ]
+        }
+    ]
+}
\ No newline at end of file
diff --git a/test/utils.jl b/test/parsing.jl
similarity index 100%
rename from test/utils.jl
rename to test/parsing.jl
diff --git a/test/runtests.jl b/test/runtests.jl
index b735084..249f35d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -9,5 +9,5 @@ using TestItems
     include("Aqua.jl")
     include("JET.jl")
     include("TestItemRunner.jl")
-    include("utils.jl")
+    include("parsing.jl")
 end
From 64ee133527871c30d2354491f6ff2001b4c6fb1d Mon Sep 17 00:00:00 2001
From: Azzaare
Date: Thu, 19 Sep 2024 11:14:04 +0900
Subject: [PATCH 09/12] Fix editor being default to vim. 
Only work with terminal based editor for now --- .gitignore | 8 + Manifest.toml | 623 ----------------------------------- Project.toml | 2 + README.md | 8 +- src/ConstraintsTranslator.jl | 4 +- src/parsing.jl | 7 +- src/translate.jl | 8 +- 7 files changed, 27 insertions(+), 633 deletions(-) create mode 100644 .gitignore delete mode 100644 Manifest.toml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f391a33 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +*.jl.*.cov +*.jl.cov +*.jl.mem +.DS_Store +.gitignore +.vscode/* +/Manifest.toml +Manifest.toml diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index 4af0f2b..0000000 --- a/Manifest.toml +++ /dev/null @@ -1,623 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.5" -manifest_format = "2.0" -project_hash = "e3750de370b9d1d33722c2ed1083743cce2dcb1e" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.BitFlags]] -git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" -uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.9" - -[[deps.CSTParser]] -deps = ["Tokenize"] -git-tree-sha1 = "0157e592151e39fa570645e2b2debcdfb8a0f112" -uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "3.4.3" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.6" - -[[deps.CommonMark]] -deps = ["Crayons", "JSON", "PrecompileTools", "URIs"] -git-tree-sha1 = "532c4185d3c9037c0237546d817858b23cf9e071" -uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6" -version = "0.8.12" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" 
-version = "4.16.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.CompositionalNetworks]] -deps = ["ConstraintCommons", "ConstraintDomains", "Dictionaries", "Distances", "JuliaFormatter", "OrderedCollections", "Random", "TestItems", "Unrolled"] -git-tree-sha1 = "42ea78627a970cc0f4d0707fb87c29a5892a65cc" -uuid = "4b67e4b5-442d-4ef5-b760-3f5df3a57537" -version = "0.5.9" - -[[deps.ConcurrentUtilities]] -deps = ["Serialization", "Sockets"] -git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" -uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.4.2" - -[[deps.ConstraintCommons]] -deps = ["Dictionaries", "TestItems"] -git-tree-sha1 = "779227189854f846de5f72b518e50dda14c7886b" -uuid = "e37357d9-0691-492f-a822-e5ea6a920954" -version = "0.2.3" - -[[deps.ConstraintDomains]] -deps = ["ConstraintCommons", "Intervals", "PatternFolds", "StatsBase", "TestItems"] -git-tree-sha1 = "02380c829c947c0579864c51affa1646a170d037" -uuid = "5800fd60-8556-4464-8d61-84ebf7a0bedb" -version = "0.3.13" - -[[deps.Constraints]] -deps = ["CompositionalNetworks", "ConstraintCommons", "ConstraintDomains", "DataFrames", "Dictionaries", "MacroTools", "PrettyTables", "TestItems"] -git-tree-sha1 = "8256d3a55ad8e7be10fa4e18325ad39dfbd24c68" -uuid = "30f324ab-b02d-43f0-b619-e131c61659f7" -version = "0.5.7" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", 
"PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.Dictionaries]] -deps = ["Indexing", "Random", "Serialization"] -git-tree-sha1 = "35b66b6744b2d92c778afd3a88d2571875664a2a" -uuid = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" -version = "0.4.2" - -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.11" - - [deps.Distances.extensions] - DistancesChainRulesCoreExt = "ChainRulesCore" - DistancesSparseArraysExt = "SparseArrays" - - [deps.Distances.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.ExceptionUnwrapping]] -deps = ["Test"] -git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" -uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" -version = "0.1.10" - -[[deps.ExprTools]] -git-tree-sha1 = 
"27415f162e6028e81c72b82ef756bf321213b6ec" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.10" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.Glob]] -git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" -uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.1" - -[[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.8" - -[[deps.Indexing]] -git-tree-sha1 = "ce1566720fd6b19ff3411404d4b977acd4814f9f" -uuid = "313cdc1a-70c2-5d6a-ae34-0150d3930a38" -version = "1.1.1" - -[[deps.InlineStrings]] -git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.2" - - [deps.InlineStrings.extensions] - ArrowTypesExt = "ArrowTypes" - ParsersExt = "Parsers" - - [deps.InlineStrings.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.Intervals]] -deps = ["Dates", "Printf", "RecipesBase", "Serialization", "TimeZones"] -git-tree-sha1 = "ac0aaa807ed5eaf13f67afe188ebc07e828ff640" -uuid = "d8418881-c3e1-53bb-8760-2df7ec849ed5" -version = "1.10.0" - -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = 
"a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.6.0" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[deps.JSON3]] -deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" -uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.0" - - [deps.JSON3.extensions] - JSON3ArrowExt = ["ArrowTypes"] - - [deps.JSON3.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - -[[deps.JSONSchema]] -deps = ["Downloads", "JSON", "JSON3", "URIs"] -git-tree-sha1 = "243f1cdb476835d7c249deb9f29ad6b7827da7d3" -uuid = "7d188eb4-7ad8-530c-ae41-71a32a6d4692" -version = "1.4.1" - -[[deps.JuliaFormatter]] -deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "PrecompileTools", "TOML", "Tokenize"] -git-tree-sha1 = "bb4696471330275adfd6c78c6173f276e8c067aa" -uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" -version = "1.0.60" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" - -[[deps.Lazy]] -deps = ["MacroTools"] -git-tree-sha1 = "1370f8202dac30758f3c345f9909b97f53d87d3f" -uuid = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" -version = "0.15.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", 
"NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.28" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.3" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] -git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.9" - 
-[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.2.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.Mocking]] -deps = ["Compat", "ExprTools"] -git-tree-sha1 = "2c140d60d7cb82badf06d8783800d0bcd1a7daa2" -uuid = "78c3b35d-d492-501b-9361-3d52fe80e533" -version = "0.8.1" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" - -[[deps.OpenSSL]] -deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" -uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.3" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1b35263570443fdd9e76c76b7062116e2f374ab8" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.15+0" - -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" - -[[deps.PatternFolds]] -deps = ["Intervals", "Lazy", "Random", "Reexport", "TestItemRunner", "TestItems"] -git-tree-sha1 = "21fb4c221aca131474a886a015a3cd5b1a42b6d2" -uuid = "c18a7f1d-76ad-4ce4-950d-5419b888513b" -version = "0.2.5" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", 
"Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.3" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RecipesBase]] -deps = ["PrecompileTools"] -git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.3.4" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.1" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" 
-version = "1.4.5" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" -uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.3" - -[[deps.StringManipulation]] -deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" - -[[deps.StructTypes]] -deps = ["Dates", "UUIDs"] -git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.11.0" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TZJData]] -deps = ["Artifacts"] -git-tree-sha1 = 
"36b40607bf2bf856828690e097e1c799623b0602" -uuid = "dc5dba14-91b3-4cab-a142-028a31da12f7" -version = "1.3.0+2024b" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.12.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TestItemRunner]] -deps = ["Pkg", "TOML", "Test", "TestItems", "UUIDs"] -git-tree-sha1 = "29647c5398be04a1d697265ba385bdf3f623c993" -uuid = "f8b46487-2199-4994-9208-9a1283c18c0a" -version = "1.0.5" - -[[deps.TestItems]] -git-tree-sha1 = "42fd9023fef18b9b78c8343a4e2f3813ffbcefcb" -uuid = "1c621080-faea-4a02-84b6-bbd5e436b8fe" -version = "1.0.0" - -[[deps.TimeZones]] -deps = ["Dates", "Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"] -git-tree-sha1 = "8323074bc977aa85cf5ad71099a83ac75b0ac107" -uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" -version = "1.18.1" -weakdeps = ["RecipesBase"] - - [deps.TimeZones.extensions] - TimeZonesRecipesBaseExt = "RecipesBase" - -[[deps.Tokenize]] -git-tree-sha1 = "468b4685af4abe0e9fd4d7bf495a6554a6276e75" -uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.29" - -[[deps.TranscodingStreams]] -git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.2" - -[[deps.URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" - -[[deps.UUIDs]] -deps = 
["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.Unrolled]] -deps = ["MacroTools"] -git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" -uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" -version = "0.1.5" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.11.0+0" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index 8e1a3ed..8040150 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.0.1" [deps] Constraints = "30f324ab-b02d-43f0-b619-e131c61659f7" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" JSONSchema = "7d188eb4-7ad8-530c-ae41-71a32a6d4692" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -15,6 +16,7 @@ TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" Aqua = "0.8" Constraints = "0.5" HTTP = "1.10" +InteractiveUtils = "1.11.0" JET = "0.9" JSON3 = "1" JSONSchema = "1" diff --git a/README.md b/README.md index 83d290b..b955169 100644 --- a/README.md +++ b/README.md @@ -24,8 +24,12 @@ using ConstraintsTranslator llm = GoogleLLM("gemini-1.5-pro") +# Optional setup of a terminal editor (uncomment and select a viable editor on your machine such as vim, nano, emacs, ...) +ENV["EDITOR"] = "vim" + + description = """ -We need to determine the shortest possible route for a salesman who must visit a set of cities exactly once and return to the starting city. 
+We need to determine the shortest possible route for a salesman who must visit a set of cities exactly once and return to the starting city. The objective is to minimize the total travel distance while ensuring that each city is visited exactly once. Example input data: @@ -43,7 +47,7 @@ CityA,CityC,8 response = translate(llm, description) ``` -The `translate` function will first produce a Markdown representation of the problem, and then return the generated Julia code for parsing the input data and building the model. +The `translate` function will first produce a Markdown representation of the problem, and then return the generated Julia code for parsing the input data and building the model. This example uses Google Gemini as an LLM. You will need an API key and a model id to access proprietary API endpoints. Use `help?>` in the Julia REPL to learn more about the available models. diff --git a/src/ConstraintsTranslator.jl b/src/ConstraintsTranslator.jl index 2defd32..902270d 100644 --- a/src/ConstraintsTranslator.jl +++ b/src/ConstraintsTranslator.jl @@ -9,6 +9,8 @@ import Constraints: USUAL_CONSTRAINTS import REPL using REPL.TerminalMenus +import InteractiveUtils + # Exports export AbstractLLM export parse_code @@ -30,4 +32,4 @@ include("llm.jl") include("parsing.jl") include("translate.jl") -end \ No newline at end of file +end diff --git a/src/parsing.jl b/src/parsing.jl index e322b47..3877700 100644 --- a/src/parsing.jl +++ b/src/parsing.jl @@ -32,11 +32,12 @@ end Edits the input string `s` in a temporary file using the Vim editor. Returns the modified string after the editor is closed. 
""" -function edit_in_vim(initial_text::String) +function edit_in_editor(initial_text::String; editor = "vim") temp_filename = tempname() write(temp_filename, initial_text) - run(`vim $temp_filename`) + InteractiveUtils.edit(temp_filename) + # run(`vim $temp_filename` edited_text = read(temp_filename, String) rm(temp_filename) return edited_text -end \ No newline at end of file +end diff --git a/src/translate.jl b/src/translate.jl index 16cb738..3c3cfd9 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -27,10 +27,10 @@ function extract_structure( if choice == 1 break elseif choice == 2 - response = edit_in_vim(response) + response = edit_in_editor(response) println(response) elseif choice == 3 - description = edit_in_vim(description) + description = edit_in_editor(description) prompt = format_template(template; description, constraints) response = stream_completion(model, prompt) elseif choice == 4 @@ -65,10 +65,10 @@ function jumpify_model( if choice == 1 break elseif choice == 2 - response = edit_in_vim(response) + response = edit_in_editor(response) println(response) elseif choice == 3 - description = edit_in_vim(description) + description = edit_in_editor(description) prompt = format_template(template; description, examples) response = stream_completion(model, prompt) elseif choice == 4 From efcf7a8685e5fcd4612901f41092fda8237118c6 Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Thu, 19 Sep 2024 12:30:06 +0900 Subject: [PATCH 10/12] Add autofixing of syntax errors --- Project.toml | 2 +- src/ConstraintsTranslator.jl | 7 ++-- src/translate.jl | 62 ++++++++++++++++++++++++++++------- templates/FixJuliaSyntax.json | 22 +++++++++++++ 4 files changed, 76 insertions(+), 17 deletions(-) create mode 100644 templates/FixJuliaSyntax.json diff --git a/Project.toml b/Project.toml index 8040150..c1ead50 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,7 @@ TestItems = 
"1c621080-faea-4a02-84b6-bbd5e436b8fe" Aqua = "0.8" Constraints = "0.5" HTTP = "1.10" -InteractiveUtils = "1.11.0" +InteractiveUtils = "1" JET = "0.9" JSON3 = "1" JSONSchema = "1" diff --git a/src/ConstraintsTranslator.jl b/src/ConstraintsTranslator.jl index 902270d..30a5fd8 100644 --- a/src/ConstraintsTranslator.jl +++ b/src/ConstraintsTranslator.jl @@ -1,15 +1,14 @@ module ConstraintsTranslator # Imports +import Constraints: USUAL_CONSTRAINTS import HTTP +import InteractiveUtils import JSONSchema import JSON3 -import TestItems: @testitem -import Constraints: USUAL_CONSTRAINTS import REPL using REPL.TerminalMenus - -import InteractiveUtils +import TestItems: @testitem # Exports export AbstractLLM diff --git a/src/translate.jl b/src/translate.jl index 3c3cfd9..9b4d490 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -1,7 +1,8 @@ """ extract_structure(model <: AbstractLLM, description <: AbstractString) -Extracts the parameters, decision variables and constraints of an optimization problem given a natural-language `description`. -Returns a plaintext Markdown-formatted document containing the above information. +Extracts the parameters, decision variables and constraints of an optimization problem +given a natural-language `description`. +Returns a Markdown-formatted text containing the above information. """ function extract_structure( model::AbstractLLM, @@ -42,6 +43,15 @@ function extract_structure( return response end +""" + jumpify_model(model::AbstractLLM, description::AbstractString, examples::AbstractString) +Translates the natural language `description` of an optimization problem into a JuMP constraints +programming model to be solved with CBL by querying the Large Language Model `model`. +The `examples` are snippets from `ConstraintModels.jl` used as in-context examples to the LLM. +To work optimally, the model expects the `description` to be a structured Markdown-formatted +description as the ones generated by `extract_structure`. 
+Returns a Markdown-formatted text containing Julia code in a code block. +""" function jumpify_model( model::AbstractLLM, description::AbstractString, @@ -53,14 +63,25 @@ function jumpify_model( prompt = format_template(template; description, examples) response = stream_completion(model, prompt) - options = [ - "Accept the response", - "Edit the response", - "Try again with a different prompt", - "Try again with the same prompt", - ] - menu = RadioMenu(options; pagesize = 4) while true + code = parse_code(response)["julia"] + parsed_expr = Meta.parse(code, raise = false) + error_message = "" + if parsed_expr.head == :incomplete || parsed_expr.head == :error + parse_error = parsed_expr.args[1] + error_message = string(parse_error) + end + options = [ + "Accept the response", + "Edit the response", + "Try again with a different prompt", + "Try again with the same prompt", + ] + if !isempty(error_message) + @warn "The generated Julia code has one or more syntax errors!" + push!(options, "Fix syntax errors") + end + menu = RadioMenu(options; pagesize = 5) choice = request("What do you want to do?", menu) if choice == 1 break @@ -73,6 +94,8 @@ function jumpify_model( response = stream_completion(model, prompt) elseif choice == 4 response = stream_completion(model, prompt) + elseif choice == 5 + response = fix_syntax_errors(model, code, error_message) elseif choice == -1 InterruptException() end @@ -81,9 +104,24 @@ function jumpify_model( end """ - translate(description::String, model::String) -Translate the natural-language `description` of an optimization problem into a Constraint Programming model -by querying the Large Language Model `model`. + fix_syntax_errors(model::AbstractLLM, code::AbstractString, error::AbstractString) +Fixes syntax errors in the `code` by querying the Large Language Model `model`, based on +an `error` produced by the Julia parser. +Returns Markdown-formatted text containing the corrected code in a Julia code block. 
+""" +function fix_syntax_errors(model::AbstractLLM, code::AbstractString, error::AbstractString) + package_path::String = pkgdir(@__MODULE__) + template_path = joinpath(package_path, "templates", "FixJuliaSyntax.json") + template = read_template(template_path) + prompt = format_template(template; code, error) + response = stream_completion(model, prompt) + return response +end + +""" + translate(model::AbstractLLM, description::AbstractString) +Translate the natural-language `description` of an optimization problem into +a Constraint Programming model by querying the Large Language Model `model`. """ function translate(model::AbstractLLM, description::AbstractString) constraints = String[] diff --git a/templates/FixJuliaSyntax.json b/templates/FixJuliaSyntax.json new file mode 100644 index 0000000..dd7be44 --- /dev/null +++ b/templates/FixJuliaSyntax.json @@ -0,0 +1,22 @@ +[ + { + "content": "Template Metadata", + "description": "Instructs the LLM to resolve syntax errors in Julia code.", + "version": "2.0", + "source": "", + "_type": "metadatamessage" + }, + { + "content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved.\nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code. 2. 
Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block.", + "variables": [], + "_type": "systemmessage" + }, + { + "content": "Code: {{code}}\n\nError: {{error}}", + "variables": [ + "code", + "error" + ], + "_type": "usermessage" + } +] \ No newline at end of file From 78be193fe4e5aa0289d4374a373c48699b1ab6f4 Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Thu, 19 Sep 2024 14:10:51 +0900 Subject: [PATCH 11/12] Add default models to llm constructor --- src/llm.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llm.jl b/src/llm.jl index 2ce08da..7ad7d2e 100644 --- a/src/llm.jl +++ b/src/llm.jl @@ -14,7 +14,7 @@ struct GroqLLM <: AbstractLLM api_key::String model_id::String - function GroqLLM(model_id::String) + function GroqLLM(model_id::String = "llama-3.1-8b-instant") api_key = get(ENV, "GROQ_API_KEY", "") if isempty(api_key) error("Environment variable GROQ_API_KEY is not set") @@ -33,7 +33,7 @@ struct GoogleLLM <: AbstractLLM api_key::String model_id::String - function GoogleLLM(model_id::String) + function GoogleLLM(model_id::String = "gemini-1.5-pro") api_key = get(ENV, "GOOGLE_API_KEY", "") if isempty(api_key) error("Environment variable GOOGLE_API_KEY is not set") From 943e5558f40107841f1cefa6487273afce2042f4 Mon Sep 17 00:00:00 2001 From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:15:04 +0900 Subject: [PATCH 12/12] Implementation of non-interactive pipeline --- Project.toml | 1 + src/ConstraintsTranslator.jl | 3 +- src/llm.jl | 2 +- src/parsing.jl | 28 ++++-- src/template.jl | 2 +- src/translate.jl | 155 ++++++++++++++++++++-------------- src/utils.jl | 11 +++ templates/FixJuliaSyntax.json | 2 +- test/Aqua.jl | 8 +- 9 files changed, 138 insertions(+), 74 deletions(-) create mode 100644 src/utils.jl diff --git a/Project.toml b/Project.toml 
index c1ead50..b7d41e7 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ InteractiveUtils = "1" JET = "0.9" JSON3 = "1" JSONSchema = "1" +REPL = "1" Test = "1" TestItemRunner = "1" TestItems = "1" diff --git a/src/ConstraintsTranslator.jl b/src/ConstraintsTranslator.jl index 30a5fd8..126ff31 100644 --- a/src/ConstraintsTranslator.jl +++ b/src/ConstraintsTranslator.jl @@ -7,7 +7,7 @@ import InteractiveUtils import JSONSchema import JSON3 import REPL -using REPL.TerminalMenus +import REPL.TerminalMenus: RadioMenu, request import TestItems: @testitem # Exports @@ -30,5 +30,6 @@ include("template.jl") include("llm.jl") include("parsing.jl") include("translate.jl") +include("utils.jl") end diff --git a/src/llm.jl b/src/llm.jl index 7ad7d2e..83fd23c 100644 --- a/src/llm.jl +++ b/src/llm.jl @@ -33,7 +33,7 @@ struct GoogleLLM <: AbstractLLM api_key::String model_id::String - function GoogleLLM(model_id::String = "gemini-1.5-pro") + function GoogleLLM(model_id::String = "gemini-1.5-flash") api_key = get(ENV, "GOOGLE_API_KEY", "") if isempty(api_key) error("Environment variable GOOGLE_API_KEY is not set") diff --git a/src/parsing.jl b/src/parsing.jl index 3877700..24ddd68 100644 --- a/src/parsing.jl +++ b/src/parsing.jl @@ -16,17 +16,35 @@ function parse_code(s::String) # Extract the code blocks and their language annotations for m in matches lang = m.captures[1] == "" ? "plain" : m.captures[1] - code = strip(m.captures[2]) - if haskey(code_dict, lang) - code_dict[lang] *= "\n" * code - else - code_dict[lang] = code + code = m.captures[2] + if !isnothing(code) + code = strip(code) + if haskey(code_dict, lang) + code_dict[lang] *= "\n" * code + else + code_dict[lang] = code + end end end return code_dict end +""" + check_syntax_errors(s::String) +Parses the string `s` as Julia code. In the case of syntax errors, it returns the error +message of the parser as a string. Otherwise, it returns an empty string. 
+""" +function check_syntax_errors(s::String) + parsed_expr = Meta.parse(s, raise = false) + error_message = "" + if parsed_expr.head == :incomplete || parsed_expr.head == :error + parse_error = parsed_expr.args[1] + error_message = string(parse_error) + end + return error_message +end + """ edit_in_vim(s::String) Edits the input string `s` in a temporary file using the Vim editor. diff --git a/src/template.jl b/src/template.jl index caa99b1..595f9c6 100644 --- a/src/template.jl +++ b/src/template.jl @@ -80,7 +80,7 @@ function read_template(data_path::String) file_content = read(data_path, String) data = JSON3.read(file_content) - package_path = pkgdir(@__MODULE__) + package_path = get_package_path() schema_path = joinpath(package_path, "templates", "TemplateSchema.json") schema_content = read(schema_path, String) schema = JSONSchema.Schema(JSON3.read(schema_content)) diff --git a/src/translate.jl b/src/translate.jl index 9b4d490..e88ce17 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -1,3 +1,5 @@ +const MAX_RETRIES::Int = 3 + """ extract_structure(model <: AbstractLLM, description <: AbstractString) Extracts the parameters, decision variables and constraints of an optimization problem @@ -8,36 +10,40 @@ function extract_structure( model::AbstractLLM, description::AbstractString, constraints::AbstractString, + interactive::Bool, ) - package_path::String = pkgdir(@__MODULE__) - template_path = joinpath(package_path, "templates", "ExtractStructure.json") - template = read_template(template_path) - prompt = format_template(template; description, constraints) + package_path = get_package_path() + prompt_template_path = joinpath(package_path, "templates", "ExtractStructure.json") + prompt_template = read_template(prompt_template_path) + + prompt = format_template(prompt_template; description, constraints) response = stream_completion(model, prompt) - options = [ - "Accept the response", - "Edit the response", - "Try again with a different prompt", - "Try again 
with the same prompt", - ] - menu = RadioMenu(options; pagesize = 4) + if interactive + options = [ + "Accept the response", + "Edit the response", + "Try again with a different prompt", + "Try again with the same prompt", + ] + menu = RadioMenu(options; pagesize = 5) - while true - choice = request("What do you want to do?", menu) - if choice == 1 - break - elseif choice == 2 - response = edit_in_editor(response) - println(response) - elseif choice == 3 - description = edit_in_editor(description) - prompt = format_template(template; description, constraints) - response = stream_completion(model, prompt) - elseif choice == 4 - response = stream_completion(model, prompt) - elseif choice == -1 - InterruptException() + while true + choice = request("What do you want to do?", menu) + if choice == 1 + break + elseif choice == 2 + response = edit_in_editor(response) + println(response) + elseif choice == 3 + description = edit_in_editor(description) + prompt = format_template(prompt_template; description, constraints) + response = stream_completion(model, prompt) + elseif choice == 4 + response = stream_completion(model, prompt) + elseif choice == -1 + InterruptException() + end end end return response @@ -56,48 +62,63 @@ function jumpify_model( model::AbstractLLM, description::AbstractString, examples::AbstractString, + interactive::Bool, ) - package_path::String = pkgdir(@__MODULE__) + package_path = get_package_path() template_path = joinpath(package_path, "templates", "JumpifyModel.json") template = read_template(template_path) prompt = format_template(template; description, examples) response = stream_completion(model, prompt) - while true - code = parse_code(response)["julia"] - parsed_expr = Meta.parse(code, raise = false) - error_message = "" - if parsed_expr.head == :incomplete || parsed_expr.head == :error - parse_error = parsed_expr.args[1] - error_message = string(parse_error) + if interactive + while true + code = parse_code(response)["julia"] + 
error_message = check_syntax_errors(code) + + options = [ + "Accept the response", + "Edit the response", + "Try again with a different prompt", + "Try again with the same prompt", + ] + if !isempty(error_message) + @warn "The generated Julia code has one or more syntax errors!" + push!(options, "Fix syntax errors") + end + menu = RadioMenu(options; pagesize = 5) + + choice = request("What do you want to do?", menu) + if choice == 1 + break + elseif choice == 2 + response = edit_in_editor(response) + println(response) + elseif choice == 3 + description = edit_in_editor(description) + prompt = format_template(template; description, examples) + response = stream_completion(model, prompt) + elseif choice == 4 + response = stream_completion(model, prompt) + elseif choice == 5 + response = fix_syntax_errors(model, code, error_message) + elseif choice == -1 + InterruptException() + end end - options = [ - "Accept the response", - "Edit the response", - "Try again with a different prompt", - "Try again with the same prompt", - ] + else + code = parse_code(response)["julia"] + error_message = check_syntax_errors(code) if !isempty(error_message) @warn "The generated Julia code has one or more syntax errors!" 
- push!(options, "Fix syntax errors") - end - menu = RadioMenu(options; pagesize = 5) - choice = request("What do you want to do?", menu) - if choice == 1 - break - elseif choice == 2 - response = edit_in_editor(response) - println(response) - elseif choice == 3 - description = edit_in_editor(description) - prompt = format_template(template; description, examples) - response = stream_completion(model, prompt) - elseif choice == 4 - response = stream_completion(model, prompt) - elseif choice == 5 - response = fix_syntax_errors(model, code, error_message) - elseif choice == -1 - InterruptException() + for _ in 1:MAX_RETRIES + response = fix_syntax_errors(model, code, error_message) + code = parse_code(response)["julia"] + error_message = check_syntax_errors(code) + if isempty(error_message) + break + end + @warn "The generated Julia code has one or more syntax errors!" + end end end return response @@ -110,7 +131,7 @@ an `error` produced by the Julia parser. Returns Markdown-formatted text containing the corrected code in a Julia code block. """ function fix_syntax_errors(model::AbstractLLM, code::AbstractString, error::AbstractString) - package_path::String = pkgdir(@__MODULE__) + package_path = get_package_path() template_path = joinpath(package_path, "templates", "FixJuliaSyntax.json") template = read_template(template_path) prompt = format_template(template; code, error) @@ -122,17 +143,23 @@ end translate(model::AbstractLLM, description::AbstractString) Translate the natural-language `description` of an optimization problem into a Constraint Programming model by querying the Large Language Model `model`. +If `interactive`, the user will be prompted via the command line to inspect the +intermediate outputs of the LLM, and possibly modify them. 
""" -function translate(model::AbstractLLM, description::AbstractString) +function translate( + model::AbstractLLM, + description::AbstractString, + interactive::Bool = false, +) constraints = String[] for (name, cons) in USUAL_CONSTRAINTS push!(constraints, "$(name): $(lstrip(cons.description))") end constraints = join(constraints, "\n") - structure = extract_structure(model, description, constraints) + structure = extract_structure(model, description, constraints, interactive) - package_path::String = pkgdir(@__MODULE__) + package_path = get_package_path() examples_path = joinpath(package_path, "examples") examples_files = filter(x -> endswith(x, ".md"), readdir(examples_path)) examples = [] @@ -142,7 +169,7 @@ function translate(model::AbstractLLM, description::AbstractString) end examples = join(examples, "\n") - response = jumpify_model(model, structure, examples) + response = jumpify_model(model, structure, examples, interactive) return parse_code(response)["julia"] end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..bc9a4e2 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,11 @@ +""" + get_package_path() +Returns the absolute path of the root directory of `ConstraintsTranslator.jl`. +""" +function get_package_path() + package_path = pkgdir(@__MODULE__) + if isnothing(package_path) + error("The path of the package could not be found. This should never happen!") + end + return package_path +end \ No newline at end of file diff --git a/templates/FixJuliaSyntax.json b/templates/FixJuliaSyntax.json index dd7be44..3ffa472 100644 --- a/templates/FixJuliaSyntax.json +++ b/templates/FixJuliaSyntax.json @@ -7,7 +7,7 @@ "_type": "metadatamessage" }, { - "content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved.\nIMPORTANT: 1. 
You must only fix the syntax errors without changing the functionality of the code. 2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block.", + "content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved. \nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code.\n2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block.\n3. You must report the complete code with the fix.", "variables": [], "_type": "systemmessage" }, diff --git a/test/Aqua.jl b/test/Aqua.jl index d60f23a..164f60b 100644 --- a/test/Aqua.jl +++ b/test/Aqua.jl @@ -1,3 +1,9 @@ @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(ConstraintsTranslator) + Aqua.test_all( + ConstraintsTranslator, + ambiguities = (broken = true,), + deps_compat = false, + piracies = (broken = false,), + unbound_args = (broken = false), + ) end \ No newline at end of file