diff --git a/Project.toml b/Project.toml index b7d41e7..edb60c4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ConstraintsTranslator" uuid = "314c63f5-3dda-4b35-95e7-4cc933f13053" authors = ["Jean-François BAFFIER (@Azzaare)"] -version = "0.0.1" +version = "0.0.2" [deps] Constraints = "30f324ab-b02d-43f0-b619-e131c61659f7" diff --git a/README.md b/README.md index 4c24338..a76fe3a 100644 --- a/README.md +++ b/README.md @@ -55,11 +55,13 @@ CityA,CityB,10 CityA,CityC,8 """ -response = translate(llm, description) +response = translate(llm, description, interactive=true) ``` The `translate` function will first produce a Markdown representation of the problem, and then return the generated Julia code for parsing the input data and building the model. +The flag `interactive=true` will enable a simple interactive command-line application, where you will be able to inspect, edit and regenerate each intermediate output. + At each generation step, it will prompt the user in an interactive menu to accept the answer, edit the prompt and/or the generated text, or generate another answer with the same prompt. The LLM expects the user to provide examples of the input data format. If no examples are present, the LLM will make assumptions about the data format based on the problem description. diff --git a/src/llm.jl b/src/llm.jl index 83fd23c..1fe86f7 100644 --- a/src/llm.jl +++ b/src/llm.jl @@ -1,44 +1,48 @@ const GROQ_URL::String = "https://api.groq.com/openai/v1/chat/completions" -const GEMINI_URL::String = "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}:generateContent" -const GEMINI_URL_STREAM::String = "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}:streamGenerateContent?alt=sse" +const GEMINI_URL::String = "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}" abstract type AbstractLLM end +abstract type OpenAILLM <: AbstractLLM end """ GroqLLM Structure encapsulating the parameters for accessing the Groq LLM API. -- `api_key`: an API key for accessing the Groq API (https://groq.com), read from the environmental variable GROQ_API_KEY +- `api_key`: an API key for accessing the Groq API (https://groq.com), read from the environmental variable GROQ_API_KEY. - `model_id`: a string identifier for the model to query. See https://console.groq.com/docs/models for the list of available models. +- `url`: URL for chat completions. Defaults to "https://api.groq.com/openai/v1/chat/completions". """ -struct GroqLLM <: AbstractLLM +struct GroqLLM <: OpenAILLM api_key::String model_id::String + url::String - function GroqLLM(model_id::String = "llama-3.1-8b-instant") + function GroqLLM(model_id::String = "llama3-70b-8192", url = GROQ_URL) api_key = get(ENV, "GROQ_API_KEY", "") if isempty(api_key) error("Environment variable GROQ_API_KEY is not set") end - new(api_key, model_id) + new(api_key, model_id, url) end end """ Google LLM Structure encapsulating the parameters for accessing the Google LLM API. -- `api_key`: an API key for accessing the Google Gemini API (https://ai.google.dev/gemini-api/docs/), read from the environmental variable GOOGLE_API_KEY +- `api_key`: an API key for accessing the Google Gemini API (https://ai.google.dev/gemini-api/docs/), read from the environmental variable GOOGLE_API_KEY. - `model_id`: a string identifier for the model to query. See https://ai.google.dev/gemini-api/docs/models/gemini for the list of available models. +- `url`: URL for chat completions. Defaults to ""https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}". """ struct GoogleLLM <: AbstractLLM api_key::String model_id::String + url::String function GoogleLLM(model_id::String = "gemini-1.5-flash") api_key = get(ENV, "GOOGLE_API_KEY", "") if isempty(api_key) error("Environment variable GOOGLE_API_KEY is not set") end - new(api_key, model_id) + new(api_key, model_id, GEMINI_URL) end end @@ -46,25 +50,27 @@ end LlamaCppLLM Structure encapsulating the parameters for accessing the llama.cpp server API. - `api_key`: an optional API key for accessing the server +- `model_id`: a string identifier for the model to query. Unused, kept for API compatibility. - `url`: the URL of the llama.cpp server OpenAI API endpoint (e.g., http://localhost:8080) NOTE: we do not apply the appropriate chat templates to the prompt. This must be handled either in an external code path or by the server. """ -struct LlamaCppLLM <: AbstractLLM +struct LlamaCppLLM <: OpenAILLM api_key::String + model_id::String url::String function LlamaCppLLM(url::String) api_key = get(ENV, "LLAMA_CPP_API_KEY", "no-key") - new(api_key, url) + new(api_key, "hal-9000-v2", url) end end """ - get_completion(llm::GroqLLM, prompt::Prompt) -Returns a completion for the given prompt using the Groq LLM API. + get_completion(llm::OpenAILLM, prompt::Prompt) +Returns a completion for the given prompt using an OpenAI API compatible LLM """ -function get_completion(llm::GroqLLM, prompt::Prompt) +function get_completion(llm::OpenAILLM, prompt::Prompt) headers = [ "Authorization" => "Bearer $(llm.api_key)", "Content-Type" => "application/json", @@ -76,7 +82,7 @@ function get_completion(llm::GroqLLM, prompt::Prompt) ], "model" => llm.model_id, )) - response = HTTP.post(GROQ_URL, headers, body) + response = HTTP.post(llm.url, headers, body) body = JSON3.read(response.body) return body["choices"][1]["message"]["content"] end @@ -86,7 +92,8 @@ end Returns a completion for the given prompt using the Google Gemini LLM API. """ function get_completion(llm::GoogleLLM, prompt::Prompt) - url = replace(GEMINI_URL, "{{model_id}}" => llm.model_id) + url = replace(llm.url, "{{model_id}}" => llm.model_id) + url *= ":generateContent" headers = [ "x-goog-api-key" => "$(llm.api_key)", "Content-Type" => "application/json", @@ -102,85 +109,11 @@ function get_completion(llm::GoogleLLM, prompt::Prompt) end """ - get_completion(llm::LlamaCppLLM, prompt::Prompt) -Returns a completion for the given prompt using the llama.cpp server API. -""" -function get_completion(llm::LlamaCppLLM, prompt::Prompt) - url = join([llm.url, "v1/chat/completions"], "/") - header = [ - "Authorization" => "Bearer $(llm.api_key)", - "Content-Type" => "application/json", - ] - body = JSON3.write(Dict( - "messages" => [ - Dict("role" => "system", "content" => prompt.system), - Dict("role" => "user", "content" => prompt.user), - ], - )) - response = HTTP.post(url, header, body) - body = JSON3.read(response.body) - return body["choices"][1]["message"]["content"] -end - -""" - stream_completion(llm::LlamaCppLLM, prompt::Prompt) -Returns a completion for the given prompt using the Groq LLM API. -The completion is streamed to the terminal as it is generated. -""" -function stream_completion(llm::LlamaCppLLM, prompt::Prompt) - url = join([llm.url, "v1/chat/completions"], "/") - headers = [ - "Authorization" => "Bearer $(llm.api_key)", - "Content-Type" => "application/json", - ] - body = JSON3.write(Dict( - "messages" => [ - Dict("role" => "system", "content" => prompt.system), - Dict("role" => "user", "content" => prompt.user), - ], - "stream" => true, - )) - - accumulated_content = "" - event_buffer = "" - - HTTP.open(:POST, url, headers; body = body) do io - write(io, body) - HTTP.closewrite(io) - HTTP.startread(io) - while !eof(io) - chunk = String(readavailable(io)) - events = split(chunk, "\n\n") - if !endswith(event_buffer, "\n\n") - event_buffer = events[end] - events = events[1:(end - 1)] - else - event_buffer = "" - end - events = join(events, "\n") - for line in eachmatch(r"(?<=data: ).*", events, overlap = true) - if line.match == "[DONE]" - print("\n") - break - end - message = JSON3.read(line.match) - if !isempty(message["choices"][1]["delta"]) - print(message["choices"][1]["delta"]["content"]) - accumulated_content *= message["choices"][1]["delta"]["content"] - end - end - end - HTTP.closeread(io) - end - return accumulated_content -end - -""" - stream_completion(llm::GroqLLM, prompt::Prompt) -Returns a completion for the given prompt using the Groq LLM API. + stream_completion(llm::OpenAILLM, prompt::Prompt) +Returns a completion for the given prompt using an OpenAI API compatible model. The completion is streamed to the terminal as it is generated. """ -function stream_completion(llm::GroqLLM, prompt::Prompt) +function stream_completion(llm::OpenAILLM, prompt::Prompt) headers = [ "Authorization" => "Bearer $(llm.api_key)", "Content-Type" => "application/json", @@ -197,29 +130,32 @@ function stream_completion(llm::GroqLLM, prompt::Prompt) accumulated_content = "" event_buffer = "" - HTTP.open(:POST, GROQ_URL, headers; body = body) do io + HTTP.open(:POST, llm.url, headers; body = body) do io write(io, body) HTTP.closewrite(io) HTTP.startread(io) while !eof(io) chunk = String(readavailable(io)) - events = split(chunk, "\n\n") + event_buffer *= chunk + events = split(event_buffer, "\n\n") if !endswith(event_buffer, "\n\n") event_buffer = events[end] events = events[1:(end - 1)] else event_buffer = "" end - events = join(events, "\n") - for line in eachmatch(r"(?<=data: ).*", events, overlap = true) - if line.match == "[DONE]" - print("\n") - break - end - message = JSON3.read(line.match) - if !isempty(message["choices"][1]["delta"]) - print(message["choices"][1]["delta"]["content"]) - accumulated_content *= message["choices"][1]["delta"]["content"] + + for event in events + for line in eachmatch(r"(?<=data: ).*", event) + if line.match == "[DONE]" + print("\n") + return accumulated_content + end + message = JSON3.read(line.match) + if !isempty(message["choices"][1]["delta"]) + print(message["choices"][1]["delta"]["content"]) + accumulated_content *= message["choices"][1]["delta"]["content"] + end end end end @@ -234,7 +170,8 @@ Returns a completion for the given prompt using the Google Gemini LLM API. The completion is streamed to the terminal as it is generated. """ function stream_completion(llm::GoogleLLM, prompt::Prompt) - url = replace(GEMINI_URL_STREAM, "{{model_id}}" => llm.model_id) + url = replace(llm.url, "{{model_id}}" => llm.model_id) + url *= ":streamGenerateContent?alt=sse" headers = [ "x-goog-api-key" => "$(llm.api_key)", "Content-Type" => "application/json", @@ -253,14 +190,14 @@ function stream_completion(llm::GoogleLLM, prompt::Prompt) HTTP.startread(io) while !eof(io) chunk = String(readavailable(io)) - line = match(r"(?<=data: ).*", chunk) - if isnothing(line) - print("\n") - break + for line in eachmatch(r"(?<=data: ).*", chunk) + if isnothing(line) + continue + end + message = JSON3.read(line.match) + print(message["candidates"][1]["content"]["parts"][1]["text"]) + accumulated_content *= String(message["candidates"][1]["content"]["parts"][1]["text"]) end - message = JSON3.read(line.match) - print(message["candidates"][1]["content"]["parts"][1]["text"]) - accumulated_content *= String(message["candidates"][1]["content"]["parts"][1]["text"]) end HTTP.closeread(io) end diff --git a/src/translate.jl b/src/translate.jl index e88ce17..b7feaef 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -140,7 +140,7 @@ function fix_syntax_errors(model::AbstractLLM, code::AbstractString, error::Abst end """ - translate(model::AbstractLLM, description::AbstractString) + translate(model::AbstractLLM, description::AbstractString; interactive::Bool = false) Translate the natural-language `description` of an optimization problem into a Constraint Programming model by querying the Large Language Model `model`. If `interactive`, the user will be prompted via the command line to inspect the @@ -148,7 +148,7 @@ intermediate outputs of the LLM, and possibly modify them. """ function translate( model::AbstractLLM, - description::AbstractString, + description::AbstractString; interactive::Bool = false, ) constraints = String[] diff --git a/templates/FixJuliaSyntax.json b/templates/FixJuliaSyntax.json index 3ffa472..6cf5674 100644 --- a/templates/FixJuliaSyntax.json +++ b/templates/FixJuliaSyntax.json @@ -7,7 +7,7 @@ "_type": "metadatamessage" }, { - "content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved. \nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code.\n2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block.\n3. You must report the complete code with the fix.", + "content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved. \nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code.\n2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block (i.e., ```julia [your code here] ```.\n3. You must report the complete code with the fix.", "variables": [], "_type": "systemmessage" }, diff --git a/templates/JumpifyModel.json b/templates/JumpifyModel.json index de63bda..9e9cc9c 100644 --- a/templates/JumpifyModel.json +++ b/templates/JumpifyModel.json @@ -7,7 +7,7 @@ "_type": "metadatamessage" }, { - "content": "You are an AI assistant specialized in modeling Constraint Programming problems. Your task is to examine a given description of a Constraint Programming model and provide a code implementation in Julia, using JuMP and the CBLS solver. The code MUST: 1) Read the input data from external files into data structures according to the specifications provided in the description, using the appropriate Julia packages (e.g., DataFrames.jl, CSV.jl, etc.), 2) build the model, and 3) return the model.\nConstraints MUST be expressed with the following JuMP syntax: `@constraint(model, x in ConstraintName(kwargs))`, where `x` is a vector of variables, `ConstraintName` is the name of the constraint in camel-case (example: all different constraint -> AllDifferent()), and `kwargs` are the keyword arguments for the constraint (example: Sum(op=<=, val=10).\nIMPORTANT: 1. Output only the required function with no additional text or usage examples.\n2. You must write a docstring for the code.\n3. The code must be succinct and capture all the described constraints.\n4. You MUST use the provide syntax to express constraints. Do NOT express constraints in algebraic form.\n\n{{examples}}", + "content": "You are an AI assistant specialized in modeling Constraint Programming problems. Your task is to examine a given description of a Constraint Programming model and provide a code implementation in Julia, using JuMP and the CBLS solver. The code MUST: 1) Read the input data from external files into data structures according to the specifications provided in the description, using the appropriate Julia packages (e.g., DataFrames.jl, CSV.jl, etc.), 2) build the model, and 3) return the model.\nConstraints MUST be expressed with the following JuMP syntax: `@constraint(model, x in ConstraintName(kwargs))`, where `x` is a vector of variables, `ConstraintName` is the name of the constraint in camel-case (example: all different constraint -> AllDifferent()), and `kwargs` are the keyword arguments for the constraint (example: Sum(op=<=, val=10).\nIMPORTANT: 1. Output only the required function with no additional text or usage examples. The code must be wrapped in a Julia code block (i.e., ```julia [your code here] ```).\n2. You must write a docstring for the code.\n3. The code must be succinct and capture all the described constraints.\n4. You MUST use the provide syntax to express constraints. Do NOT express constraints in algebraic form.\n\n{{examples}}", "variables": [ "examples" ], diff --git a/test/JET.jl b/test/JET.jl index a030a2a..69b9600 100644 --- a/test/JET.jl +++ b/test/JET.jl @@ -1,3 +1,5 @@ @testset "Code linting (JET.jl)" begin - JET.test_package(ConstraintsTranslator; target_defined_modules = true) -end \ No newline at end of file + if VERSION ≤ v"1.10" + JET.test_package(ConstraintsTranslator; target_defined_modules = true) + end +end