Skip to content

Commit

Permalink
Implementation of non-interactive pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoladicicco committed Sep 20, 2024
1 parent 78be193 commit 943e555
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 74 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ InteractiveUtils = "1"
JET = "0.9"
JSON3 = "1"
JSONSchema = "1"
REPL = "1"
Test = "1"
TestItemRunner = "1"
TestItems = "1"
Expand Down
3 changes: 2 additions & 1 deletion src/ConstraintsTranslator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import InteractiveUtils
import JSONSchema
import JSON3
import REPL
using REPL.TerminalMenus
import REPL.TerminalMenus: RadioMenu, request
import TestItems: @testitem

# Exports
Expand All @@ -30,5 +30,6 @@ include("template.jl")
include("llm.jl")
include("parsing.jl")
include("translate.jl")
include("utils.jl")

end
2 changes: 1 addition & 1 deletion src/llm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct GoogleLLM <: AbstractLLM
api_key::String
model_id::String

function GoogleLLM(model_id::String = "gemini-1.5-pro")
function GoogleLLM(model_id::String = "gemini-1.5-flash")
api_key = get(ENV, "GOOGLE_API_KEY", "")
if isempty(api_key)
error("Environment variable GOOGLE_API_KEY is not set")
Expand Down
28 changes: 23 additions & 5 deletions src/parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,35 @@ function parse_code(s::String)
# Extract the code blocks and their language annotations
for m in matches
lang = m.captures[1] == "" ? "plain" : m.captures[1]
code = strip(m.captures[2])
if haskey(code_dict, lang)
code_dict[lang] *= "\n" * code
else
code_dict[lang] = code
code = m.captures[2]
if !isnothing(code)
code = strip(code)
if haskey(code_dict, lang)
code_dict[lang] *= "\n" * code
else
code_dict[lang] = code
end
end
end

return code_dict
end

"""
check_syntax_errors(s::String)
Parses the string `s` as Julia code. In the case of syntax errors, it returns the error
message of the parser as a string. Otherwise, it returns an empty string.
"""
function check_syntax_errors(s::String)
parsed_expr = Meta.parse(s, raise = false)
error_message = ""
if parsed_expr.head == :incomplete || parsed_expr.head == :error
parse_error = parsed_expr.args[1]
error_message = string(parse_error)
end
return error_message
end

"""
edit_in_vim(s::String)
Edits the input string `s` in a temporary file using the Vim editor.
Expand Down
2 changes: 1 addition & 1 deletion src/template.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function read_template(data_path::String)
file_content = read(data_path, String)
data = JSON3.read(file_content)

package_path = pkgdir(@__MODULE__)
package_path = get_package_path()
schema_path = joinpath(package_path, "templates", "TemplateSchema.json")
schema_content = read(schema_path, String)
schema = JSONSchema.Schema(JSON3.read(schema_content))
Expand Down
155 changes: 91 additions & 64 deletions src/translate.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const MAX_RETRIES::Int = 3

"""
extract_structure(model <: AbstractLLM, description <: AbstractString)
Extracts the parameters, decision variables and constraints of an optimization problem
Expand All @@ -8,36 +10,40 @@ function extract_structure(
model::AbstractLLM,
description::AbstractString,
constraints::AbstractString,
interactive::Bool,
)
package_path::String = pkgdir(@__MODULE__)
template_path = joinpath(package_path, "templates", "ExtractStructure.json")
template = read_template(template_path)
prompt = format_template(template; description, constraints)
package_path = get_package_path()
prompt_template_path = joinpath(package_path, "templates", "ExtractStructure.json")
prompt_template = read_template(prompt_template_path)

prompt = format_template(prompt_template; description, constraints)
response = stream_completion(model, prompt)

options = [
"Accept the response",
"Edit the response",
"Try again with a different prompt",
"Try again with the same prompt",
]
menu = RadioMenu(options; pagesize = 4)
if interactive
options = [
"Accept the response",
"Edit the response",
"Try again with a different prompt",
"Try again with the same prompt",
]
menu = RadioMenu(options; pagesize = 5)

while true
choice = request("What do you want to do?", menu)
if choice == 1
break
elseif choice == 2
response = edit_in_editor(response)
println(response)
elseif choice == 3
description = edit_in_editor(description)
prompt = format_template(template; description, constraints)
response = stream_completion(model, prompt)
elseif choice == 4
response = stream_completion(model, prompt)
elseif choice == -1
InterruptException()
while true
choice = request("What do you want to do?", menu)
if choice == 1
break
elseif choice == 2
response = edit_in_editor(response)
println(response)
elseif choice == 3
description = edit_in_editor(description)
prompt = format_template(prompt_template; description, constraints)
response = stream_completion(model, prompt)
elseif choice == 4
response = stream_completion(model, prompt)
elseif choice == -1
InterruptException()
end
end
end
return response
Expand All @@ -56,48 +62,63 @@ function jumpify_model(
model::AbstractLLM,
description::AbstractString,
examples::AbstractString,
interactive::Bool,
)
package_path::String = pkgdir(@__MODULE__)
package_path = get_package_path()
template_path = joinpath(package_path, "templates", "JumpifyModel.json")
template = read_template(template_path)
prompt = format_template(template; description, examples)
response = stream_completion(model, prompt)

while true
code = parse_code(response)["julia"]
parsed_expr = Meta.parse(code, raise = false)
error_message = ""
if parsed_expr.head == :incomplete || parsed_expr.head == :error
parse_error = parsed_expr.args[1]
error_message = string(parse_error)
if interactive
while true
code = parse_code(response)["julia"]
error_message = check_syntax_errors(code)

options = [
"Accept the response",
"Edit the response",
"Try again with a different prompt",
"Try again with the same prompt",
]
if !isempty(error_message)
@warn "The generated Julia code has one or more syntax errors!"
push!(options, "Fix syntax errors")
end
menu = RadioMenu(options; pagesize = 5)

choice = request("What do you want to do?", menu)
if choice == 1
break
elseif choice == 2
response = edit_in_editor(response)
println(response)
elseif choice == 3
description = edit_in_editor(description)
prompt = format_template(template; description, examples)
response = stream_completion(model, prompt)
elseif choice == 4
response = stream_completion(model, prompt)
elseif choice == 5
response = fix_syntax_errors(model, code, error_message)
elseif choice == -1
InterruptException()
end
end
options = [
"Accept the response",
"Edit the response",
"Try again with a different prompt",
"Try again with the same prompt",
]
else
code = parse_code(response)["julia"]
error_message = check_syntax_errors(code)
if !isempty(error_message)
@warn "The generated Julia code has one or more syntax errors!"
push!(options, "Fix syntax errors")
end
menu = RadioMenu(options; pagesize = 5)
choice = request("What do you want to do?", menu)
if choice == 1
break
elseif choice == 2
response = edit_in_editor(response)
println(response)
elseif choice == 3
description = edit_in_editor(description)
prompt = format_template(template; description, examples)
response = stream_completion(model, prompt)
elseif choice == 4
response = stream_completion(model, prompt)
elseif choice == 5
response = fix_syntax_errors(model, code, error_message)
elseif choice == -1
InterruptException()
for _ in 1:MAX_RETRIES
response = fix_syntax_errors(model, code, error_message)
code = parse_code(response)["julia"]
error_message = check_syntax_errors(code)
if isempty(error_message)
break
end
@warn "The generated Julia code has one or more syntax errors!"
end
end
end
return response
Expand All @@ -110,7 +131,7 @@ an `error` produced by the Julia parser.
Returns Markdown-formatted text containing the corrected code in a Julia code block.
"""
function fix_syntax_errors(model::AbstractLLM, code::AbstractString, error::AbstractString)
package_path::String = pkgdir(@__MODULE__)
package_path = get_package_path()
template_path = joinpath(package_path, "templates", "FixJuliaSyntax.json")
template = read_template(template_path)
prompt = format_template(template; code, error)
Expand All @@ -122,17 +143,23 @@ end
translate(model::AbstractLLM, description::AbstractString)
Translate the natural-language `description` of an optimization problem into
a Constraint Programming model by querying the Large Language Model `model`.
If `interactive`, the user will be prompted via the command line to inspect the
intermediate outputs of the LLM, and possibly modify them.
"""
function translate(model::AbstractLLM, description::AbstractString)
function translate(
model::AbstractLLM,
description::AbstractString,
interactive::Bool = false,
)
constraints = String[]
for (name, cons) in USUAL_CONSTRAINTS
push!(constraints, "$(name): $(lstrip(cons.description))")
end
constraints = join(constraints, "\n")

structure = extract_structure(model, description, constraints)
structure = extract_structure(model, description, constraints, interactive)

package_path::String = pkgdir(@__MODULE__)
package_path = get_package_path()
examples_path = joinpath(package_path, "examples")
examples_files = filter(x -> endswith(x, ".md"), readdir(examples_path))
examples = []
Expand All @@ -142,7 +169,7 @@ function translate(model::AbstractLLM, description::AbstractString)
end
examples = join(examples, "\n")

response = jumpify_model(model, structure, examples)
response = jumpify_model(model, structure, examples, interactive)

return parse_code(response)["julia"]
end
11 changes: 11 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
get_package_path()
Returns the absolute path of the root directory of `ConstraintsTranslator.jl`.
"""
function get_package_path()
package_path = pkgdir(@__MODULE__)
if isnothing(package_path)
error("The path of the package could not be found. This should never happen!")
end
return package_path
end
2 changes: 1 addition & 1 deletion templates/FixJuliaSyntax.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"_type": "metadatamessage"
},
{
"content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved.\nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code. 2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block.",
"content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved. \nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code.\n2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block.\n3. You must report the complete code with the fix.",
"variables": [],
"_type": "systemmessage"
},
Expand Down
8 changes: 7 additions & 1 deletion test/Aqua.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
@testset "Code quality (Aqua.jl)" begin
Aqua.test_all(ConstraintsTranslator)
Aqua.test_all(
ConstraintsTranslator,
ambiguities = (broken = true,),
deps_compat = false,
piracies = (broken = false,),
unbound_args = (broken = false),
)
end

0 comments on commit 943e555

Please sign in to comment.