Skip to content

Commit

Permalink
switch to ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
CalebCourier committed May 17, 2024
1 parent d8d22d2 commit 6dbe917
Show file tree
Hide file tree
Showing 29 changed files with 453 additions and 443 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pr_qa.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
- name: Quality Checks
run: |
python -m pip install --upgrade pip;
make install-lock
make build;
make install-lock;
opentelemetry-bootstrap -a install;
make qa;
##
Expand Down
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ refresh:


format:
black -l 80 ./src app.py wsgi.py
ruff check app.py wsgi.py src/ tests/ --fix
ruff format app.py wsgi.py src/ tests/


lint:
flake8 --count ./src app.py wsgi.py
ruff check app.py wsgi.py src/ tests/
ruff format app.py wsgi.py src/ tests/

qa:
make lint
Expand Down
12 changes: 3 additions & 9 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,19 @@ def __init__(self, app):
def __call__(self, environ, start_response):
self_endpoint = os.environ.get("SELF_ENDPOINT", "http://localhost:8000")
url = urlparse(self_endpoint)
environ['wsgi.url_scheme'] = url.scheme
environ["wsgi.url_scheme"] = url.scheme
return self.app(environ, start_response)


def create_app():
app = Flask(__name__)

app.config['APPLICATION_ROOT'] = '/'
app.config["APPLICATION_ROOT"] = "/"
app.config["PREFERRED_URL_SCHEME"] = "https"
app.wsgi_app = ReverseProxied(app.wsgi_app)
CORS(app)

app.wsgi_app = ProxyFix(
app.wsgi_app,
x_for=1,
x_proto=1,
x_host=1,
x_port=1
)
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1)

guardrails_log_level = os.environ.get("GUARDRAILS_LOG_LEVEL", "INFO")
configure_logging(log_level=guardrails_log_level)
Expand Down
3 changes: 1 addition & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
black
flake8
ruff
pytest
coverage
pytest-mock
10 changes: 5 additions & 5 deletions src/blueprints/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def guard(guard_name: str):
405,
"Method Not Allowed",
"/guard/<guard_name> only supports the GET, PUT, and DELETE methods."
" You specified {request_method}".format(
request_method=request.method
),
" You specified {request_method}".format(request_method=request.method),
)


Expand Down Expand Up @@ -151,7 +149,7 @@ def validate(guard_name: str):
result.validation_passed,
result.validated_output,
guard.history,
result.raw_llm_output
result.raw_llm_output,
)

prompt = guard.history.last.inputs.prompt
Expand All @@ -175,7 +173,9 @@ def validate(guard_name: str):
)
validate_span.set_attribute("validated_output", valid_output)

validate_span.set_attribute("tokens_consumed", guard.history.last.tokens_consumed)
validate_span.set_attribute(
"tokens_consumed", guard.history.last.tokens_consumed
)

num_of_reasks = (
guard.history.last.iterations.length - 1
Expand Down
2 changes: 1 addition & 1 deletion src/blueprints/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def health_check():
def api_docs():
global cached_api_spec
if not cached_api_spec:
with open('./open-api-spec.json') as api_spec_file:
with open("./open-api-spec.json") as api_spec_file:
cached_api_spec = json.loads(api_spec_file.read())
return json.dumps(cached_api_spec)

Expand Down
46 changes: 14 additions & 32 deletions src/classes/data_type_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,9 @@ def from_dict(cls, data_type: dict):
class_children = {}
elem_type = element["type"] if element is not None else None
elem_is_list = elem_type == "list"
child_entries = (
children.get("item", {}) if elem_is_list else children
)
child_entries = children.get("item", {}) if elem_is_list else children
for child_key in child_entries:
class_children[child_key] = cls.from_dict(
child_entries[child_key]
)
class_children[child_key] = cls.from_dict(child_entries[child_key])
children_data_types = (
{"item": class_children} if elem_is_list else class_children
)
Expand All @@ -155,13 +151,9 @@ def to_dict(self):
self.children.get("item", {}) if elem_is_list else self.children
)
for child_key in child_entries:
serialized_children[child_key] = child_entries[
child_key
].to_dict()
serialized_children[child_key] = child_entries[child_key].to_dict()
response["children"] = (
{"item": serialized_children}
if elem_is_list
else serialized_children
{"item": serialized_children} if elem_is_list else serialized_children
)

if self.plugins is not None:
Expand All @@ -174,14 +166,16 @@ def from_request(cls, data_type: dict):
if data_type:
children, formatters, element, plugins = pluck(
data_type, ["children", "formatters", "element", "plugins"]
)
)
children_data_types = None
if children:
class_children = {}
elem_type = element.get("type") if element is not None else None
elem_is_list = elem_type == "list"
child_entries = (
children.get("item", {}).get("children", {}) if elem_is_list else children
children.get("item", {}).get("children", {})
if elem_is_list
else children
)
for child_key in child_entries:
class_children[child_key] = cls.from_request(
Expand Down Expand Up @@ -211,13 +205,9 @@ def to_response(self):
self.children.get("item", {}) if elem_is_list else self.children
)
for child_key in child_entries:
serialized_children[child_key] = child_entries[
child_key
].to_response()
serialized_children[child_key] = child_entries[child_key].to_response()
response["children"] = (
{"item": serialized_children}
if elem_is_list
else serialized_children
{"item": serialized_children} if elem_is_list else serialized_children
)

if self.plugins is not None:
Expand Down Expand Up @@ -255,9 +245,7 @@ def from_xml(cls, elem: _Element):

return cls(children, formatters, element, plugins)

def to_xml(
self, parent: _Element, as_parent: Optional[bool] = False
) -> _Element:
def to_xml(self, parent: _Element, as_parent: Optional[bool] = False) -> _Element:
element = None
if as_parent:
element = parent
Expand All @@ -267,9 +255,7 @@ def to_xml(
else:
element = self.element.to_element()

format = (
"; ".join(self.formatters) if len(self.formatters) > 0 else None
)
format = "; ".join(self.formatters) if len(self.formatters) > 0 else None
if format is not None:
element.attrib["format"] = format

Expand All @@ -278,7 +264,6 @@ def to_xml(
if plugins is not None:
element.attrib["plugins"] = plugins


stringified_attribs = {}
for k, v in element.attrib.items():
stringified_attribs[k] = str(v)
Expand All @@ -295,8 +280,7 @@ def to_xml(
)
_parent = xml_data_type
if self_is_list and (
len(child_entries) > 0
or child_entries[0].element.name is not None
len(child_entries) > 0 or child_entries[0].element.name is not None
):
_parent = SubElement(xml_data_type, "object")
for child_key in child_entries:
Expand All @@ -309,9 +293,7 @@ def get_all_plugins(self) -> List[str]:

self_is_list = self.element.type == "list"
if self.children is not None:
children = (
self.children.get("item", {}) if self_is_list else self.children
)
children = self.children.get("item", {}) if self_is_list else self.children
for child_key in children:
plugins.extend(children[child_key].get_all_plugins())
return plugins
Expand Down
10 changes: 3 additions & 7 deletions src/classes/guard_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
name: str,
railspec: RailSpecStruct,
num_reasks: int = None,
description: str = None
description: str = None,
# base_model: dict = None,
):
self.name = name
Expand Down Expand Up @@ -52,9 +52,7 @@ def from_dict(cls, guard: dict):
name, railspec, num_reasks, description = pluck(
guard, ["name", "railspec", "num_reasks", "description"]
)
return cls(
name, RailSpecStruct.from_dict(railspec), num_reasks, description
)
return cls(name, RailSpecStruct.from_dict(railspec), num_reasks, description)

def to_dict(self) -> dict:
return {
Expand Down Expand Up @@ -100,6 +98,4 @@ def from_railspec(
num_reasks: int = None,
description: str = None,
):
return cls(
name, RailSpecStruct.from_xml(railspec), num_reasks, description
)
return cls(name, RailSpecStruct.from_xml(railspec), num_reasks, description)
17 changes: 4 additions & 13 deletions src/classes/rail_spec_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@ def from_rail(cls, rail: Rail):
)

def to_rail(self) -> Rail:
input_schema = (
self.input_schema.to_schema() if self.input_schema else None
)
output_schema = (
self.output_schema.to_schema() if self.output_schema else None
)
input_schema = self.input_schema.to_schema() if self.input_schema else None
output_schema = self.output_schema.to_schema() if self.output_schema else None
# TODO: This might not be necessary anymore since we stopped
# BasePrompt from formatting on init
escaped_instructions = escape_curlys(self.instructions)
Expand All @@ -51,9 +47,7 @@ def to_rail(self) -> Rail:
# TODO: This might not be necessary anymore since we stopped
# BasePrompt from formatting on init
escaped_prompt = escape_curlys(self.prompt)
prompt = (
Prompt(escaped_prompt, output_schema) if escaped_prompt else None
)
prompt = Prompt(escaped_prompt, output_schema) if escaped_prompt else None
prompt.source = descape_curlys(prompt.source)
return Rail(
input_schema,
Expand Down Expand Up @@ -148,10 +142,7 @@ def from_xml(cls, railspec: str):
xml_parser = etree.XMLParser(encoding="utf-8")
elem_tree = etree.fromstring(railspec, parser=xml_parser)

if (
"version" not in elem_tree.attrib
or elem_tree.attrib["version"] != "0.1"
):
if "version" not in elem_tree.attrib or elem_tree.attrib["version"] != "0.1":
raise ValueError(
"RAIL file must have a version attribute set to 0.1."
"Change the opening <rail> element to: <rail version='0.1'>."
Expand Down
16 changes: 5 additions & 11 deletions src/classes/schema_element_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def to_element(self) -> ElementStub:
for validator_on_fail in self.on_fails:
validator_tag = validator_on_fail.get("validatorTag", "")
escaped_validator_tag = validator_tag.replace("/", "_")
elem_dict[
f"on-fail-{escaped_validator_tag}"
] = validator_on_fail.get("method")
elem_dict[f"on-fail-{escaped_validator_tag}"] = validator_on_fail.get(
"method"
)
elem_dict.pop("date_format", None)
elem_dict.pop("time_format", None)
elem_dict.pop("on_fails", None)
Expand Down Expand Up @@ -184,11 +184,7 @@ def from_xml(cls, xml: _Element):
strict = None
strict_tag = xml.get("strict", "False")
if strict_tag:
strict = (
True
if strict_tag.lower() == "true"
else False
)
strict = True if strict_tag.lower() == "true" else False
date_format = xml.get("date-format")
time_format = xml.get("time-format")
on_fail = xml.get("on-fail")
Expand All @@ -209,9 +205,7 @@ def from_xml(cls, xml: _Element):
if attr_key.startswith("on-fail") and attr_key != "on-fail":
on_fail_method = xml.get(attr_key)
on_fail_tag = attr_key
on_fails.append(
{"validatorTag": on_fail_tag, "method": on_fail_method}
)
on_fails.append({"validatorTag": on_fail_tag, "method": on_fail_method})
elif attr_key not in handled_keys:
kwargs[attr_key] = xml.get(attr_key)
model = xml.get("model")
Expand Down
16 changes: 4 additions & 12 deletions src/classes/schema_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ def from_schema(cls, schema: Schema):
else:
for key in schema.root_datatype:
schema_element = schema.root_datatype[key]
serialized_schema[key] = DataTypeStruct.from_data_type(
schema_element
)
serialized_schema[key] = DataTypeStruct.from_data_type(schema_element)
return cls({"schema": serialized_schema})

def to_schema(self) -> Schema:
Expand All @@ -34,9 +32,7 @@ def to_schema(self) -> Schema:
if isinstance(inner_schema, DataTypeStruct):
string_schema = StringSchema()
string_schema.string_key = inner_schema.element.name
string_schema[
string_schema.string_key
] = inner_schema.to_data_type()
string_schema[string_schema.string_key] = inner_schema.to_data_type()
return string_schema

for key in inner_schema:
Expand All @@ -55,9 +51,7 @@ def from_dict(cls, schema: dict):
else:
for key in inner_schema:
schema_element = inner_schema[key]
serialized_schema[key] = DataTypeStruct.from_dict(
schema_element
)
serialized_schema[key] = DataTypeStruct.from_dict(schema_element)
return cls({"schema": serialized_schema})

def to_dict(self):
Expand All @@ -84,9 +78,7 @@ def from_request(cls, schema: dict):
# JsonSchema
for key in inner_schema:
schema_element = inner_schema[key]
serialized_schema[key] = DataTypeStruct.from_request(
schema_element
)
serialized_schema[key] = DataTypeStruct.from_request(schema_element)
return cls({"schema": serialized_schema})

def to_response(self):
Expand Down
5 changes: 3 additions & 2 deletions src/classes/validation_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def __init__(
"instructions": i.inputs.instructions.source
if i.inputs.instructions is not None
else None,
"output": i.outputs.raw_output or i.outputs.llm_response_info.output,
"output": i.outputs.raw_output
or i.outputs.llm_response_info.output,
"parsedOutput": i.parsed_output,
"prompt": {
"source": i.inputs.prompt.source
Expand Down Expand Up @@ -49,7 +50,7 @@ def __init__(
"propertyPath": fv.property_path,
}
for fv in i.failed_validations
)
),
}
for i in c.iterations
]
Expand Down
Loading

0 comments on commit 6dbe917

Please sign in to comment.