Skip to content

Commit

Permalink
Update formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
egor-bogomolov committed Jun 5, 2024
1 parent 12a2bd3 commit f5590ee
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 46 deletions.
88 changes: 87 additions & 1 deletion library_based_code_generation/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions library_based_code_generation/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,19 @@ jupyter = "^1.0.0"
notebook = "^7.2.0"
tree-sitter-python = "^0.21.0"
rank-bm25 = "^0.2.2"
black = "^24.4.2"
isort = "^5.13.2"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.black]
line-length = 120
target-version = ["py310"]

[tool.isort]
line_length = 120
py_version = 310
profile = "black"
10 changes: 3 additions & 7 deletions library_based_code_generation/src/context/parsed_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,7 @@ def colored_code(self, other_identifiers: set):

@staticmethod
def filter_function_names(names):
return {
name
for name in names
if not (name.startswith("__") and name.endswith("__")) and not name == "super"
}
return {name for name in names if not (name.startswith("__") and name.endswith("__")) and not name == "super"}

def clean_comments(self):
start1 = bytes("'''", "utf8")
Expand All @@ -163,9 +159,9 @@ def clean_comments(self):
comment_nodes = []

def walk(node):
if 'comment' in node.type.lower():
if "comment" in node.type.lower():
comment_nodes.append(node)
elif node.type == 'string' and (node.text.startswith(start1) or node.text.startswith(start2)):
elif node.type == "string" and (node.text.startswith(start1) or node.text.startswith(start2)):
comment_nodes.append(node)
else:
for child in node.children:
Expand Down
14 changes: 5 additions & 9 deletions library_based_code_generation/src/context/parsed_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,8 @@ def __init__(self, project_root: str, skip_directories: list[str] = None):
parsed_file = ParsedFile(filepath)
self.parsed_files.append(parsed_file)

self.defined_functions = set(chain.from_iterable(
parsed_file.function_names
for parsed_file in self.parsed_files
))

self.defined_classes = set(chain.from_iterable(
parsed_file.class_names
for parsed_file in self.parsed_files
))
self.defined_functions = set(
chain.from_iterable(parsed_file.function_names for parsed_file in self.parsed_files)
)

self.defined_classes = set(chain.from_iterable(parsed_file.class_names for parsed_file in self.parsed_files))
8 changes: 4 additions & 4 deletions library_based_code_generation/src/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
import os
import pickle
from collections import defaultdict
from datasets import load_dataset

import numpy as np
from datasets import load_dataset
from tqdm import tqdm

from ..context.parsed_file import ParsedFile
from ..metrics.chrf import ChrF
from ..metrics.metric import Metric
from ..metrics.overlap import Overlap
from ..models.openai_model import OpenAIModel
from ..models.example_generation_model import ExampleGenerationModel
from ..models.openai_model import OpenAIModel
from ..models.together_model import TogetherModel
from ..context.parsed_file import ParsedFile


def extract_code(message):
Expand Down Expand Up @@ -65,7 +65,7 @@ def evaluate(model: ExampleGenerationModel, metrics: list[Metric], data_path: st
}
for metric in metrics
},
"name": model.name()
"name": model.name(),
}
with open(metadata_path, "w") as fout:
json.dump(metadata, fout)
Expand Down
2 changes: 1 addition & 1 deletion library_based_code_generation/src/metrics/overlap.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .metric import Metric
from ..context.parsed_file import ParsedFile
from .metric import Metric


class Overlap(Metric):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from abc import ABC, abstractmethod

import numpy as np
from rank_bm25 import BM25Okapi

from .utils import split_identifier
import numpy as np


class ExampleGenerationModel(ABC):
@abstractmethod
Expand Down Expand Up @@ -30,9 +33,9 @@ def get_bm25_prompt(self, instruction: str, project_apis: list[str], n_selection
predictions.append(project_apis[ind])

bm25_instruction = (
instruction +
"\n\n" +
"You can find the following APIs from the library helpful:\n" +
", ".join(predictions)
instruction
+ "\n\n"
+ "You can find the following APIs from the library helpful:\n"
+ ", ".join(predictions)
)
return self.get_prompt(bm25_instruction)
11 changes: 5 additions & 6 deletions library_based_code_generation/src/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ def __init__(self, model_name: str, use_bm25: bool = False):
self.use_bm25 = use_bm25

def generate(self, task_description: str, project_apis: list[str] = None) -> str:
instruction = self.get_prompt(task_description) \
if not self.use_bm25 \
instruction = (
self.get_prompt(task_description)
if not self.use_bm25
else self.get_bm25_prompt(task_description, project_apis)
)

prompt = [
{
"role": "user",
"content": instruction
},
{"role": "user", "content": instruction},
]
response = self.client.chat.completions.create(
model=self.model_name,
Expand Down
11 changes: 5 additions & 6 deletions library_based_code_generation/src/models/together_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ def __init__(self, model_name: str, use_bm25: bool = False):
self.use_bm25 = use_bm25

def generate(self, task_description: str, project_apis: list[str] = None) -> str:
instruction = self.get_prompt(task_description) \
if not self.use_bm25 \
instruction = (
self.get_prompt(task_description)
if not self.use_bm25
else self.get_bm25_prompt(task_description, project_apis)
)

prompt = [
{
"role": "user",
"content": instruction
},
{"role": "user", "content": instruction},
]
response = self.client.chat.completions.create(
model=self.model_name,
Expand Down
11 changes: 4 additions & 7 deletions library_based_code_generation/src/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@


def camel_case_split(identifier):
matches = re.finditer('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', identifier)
matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
return [m.group(0) for m in matches]


def snake_case_split(identifier):
return identifier.split("_")


def split_identifier(identifier):
parts = [
p.lower()
for part in snake_case_split(identifier)
for p in camel_case_split(part)
if p != ""
]
parts = [p.lower() for part in snake_case_split(identifier) for p in camel_case_split(part) if p != ""]
return parts

0 comments on commit f5590ee

Please sign in to comment.