diff --git a/cla.md b/cla.md
index 77e33bb..389a474 100644
--- a/cla.md
+++ b/cla.md
@@ -55,4 +55,4 @@ required to provide support for your Contributions, except to the extent you des
 You acknowledge that the maintainers of this project are under no obligation to use
 or incorporate your contributions into the project. The decision to use or incorporate
 your contributions into the project will be made at the
-sole discretion of the maintainers or their authorized delegates.
\ No newline at end of file
+sole discretion of the maintainers or their authorized delegates.
diff --git a/demo_helpers/MANIFEST.in b/demo_helpers/MANIFEST.in
new file mode 100644
index 0000000..276bbbd
--- /dev/null
+++ b/demo_helpers/MANIFEST.in
@@ -0,0 +1,3 @@
+include demo_helpers/datasets/README.md
+include demo_helpers/pretrained_models/m5.pt
+include demo_helpers/pretrained_models/pointnet.pth
diff --git a/demo_helpers/demo_helpers/compute_performance.py b/demo_helpers/demo_helpers/compute_performance.py
index 8e3bfef..322d687 100644
--- a/demo_helpers/demo_helpers/compute_performance.py
+++ b/demo_helpers/demo_helpers/compute_performance.py
@@ -182,16 +182,24 @@ def pytorch_model_inference(dataset, model):
             out = model(**inputs)

             if not isinstance(out, torch.Tensor):
-                if "logits" in out:
-                    out = out.logits
-                elif "start_logits" in out and "end_logits" in out:
-                    out = torch.vstack((out["start_logits"], out["end_logits"]))
-                elif "last_hidden_state" in out:
-                    out = out.last_hidden_state
+                if isinstance(out, tuple):
+                    if len(out) == 1:
+                        out = out[0]
+                    else:
+                        raise ValueError("Cannot handle tuple with len", len(out))
+                elif isinstance(out, dict):
+                    if "logits" in out:
+                        out = out.logits
+                    elif "start_logits" in out and "end_logits" in out:
+                        out = torch.vstack((out["start_logits"], out["end_logits"]))
+                    elif "last_hidden_state" in out:
+                        out = out.last_hidden_state
+                    else:
+                        raise ValueError(
+                            "Unknown output key. List of keys:", list(out.keys())
+                        )
                 else:
-                    raise ValueError(
-                        "Unknown output key. List of keys:", list(out.keys())
-                    )
+                    raise ValueError("Unknown output type", type(out))
             pred.append(out)

     return dataset.postprocess(pred)
diff --git a/demo_helpers/demo_helpers/dataset.py b/demo_helpers/demo_helpers/dataset.py
index 0e7d56c..1bbe307 100644
--- a/demo_helpers/demo_helpers/dataset.py
+++ b/demo_helpers/demo_helpers/dataset.py
@@ -856,7 +856,7 @@ def _download_dataset(self):

     def preprocess(self):
         return [
-            {"image_arrays": self.coco[i][0].unsqueeze(0).numpy()}
+            {"images": self.coco[i][0].unsqueeze(0).numpy()}
             for i in range(len(self.coco))
         ]
diff --git a/demo_helpers/demo_helpers/model_download.py b/demo_helpers/demo_helpers/model_download.py
index 87e4ab1..a2085be 100644
--- a/demo_helpers/demo_helpers/model_download.py
+++ b/demo_helpers/demo_helpers/model_download.py
@@ -1,19 +1,23 @@
 import os
+import zipfile

 from datasets.utils.file_utils import cached_path

 from groqflow.common.build import DEFAULT_CACHE_DIR

-YOLOV6N_ONNX = "yolov6n_onnx"
+YOLOV6N_MODEL = "yolov6n_model"
+YOLOV6N_SOURCE = "yolov6n_source"

 DATA_URLS = {
-    YOLOV6N_ONNX: "https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n.onnx",
+    YOLOV6N_MODEL: "https://github.com/meituan/YOLOv6/releases/download/0.4.0/yolov6n.pt",
+    YOLOV6N_SOURCE: "https://github.com/meituan/YOLOv6/archive/refs/tags/0.4.0.zip",
 }

 DST_PATHS = {
-    YOLOV6N_ONNX: "onnx_models/yolov6n.onnx",
+    YOLOV6N_MODEL: "pytorch_models/yolov6_nano/yolov6n.pt",
+    YOLOV6N_SOURCE: "pytorch_models/yolov6_nano/YOLOv6",
 }

@@ -27,3 +31,18 @@ def download_model(model):
     download_path = cached_path(url)
     os.symlink(download_path, dst_path)
     return dst_path
+
+
+def download_source(source):
+    dst_path = os.path.join(DEFAULT_CACHE_DIR, DST_PATHS[source])
+    if os.path.exists(dst_path):
+        return dst_path
+
+    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
+    url = DATA_URLS[source]
+    download_path = cached_path(url)
+    with zipfile.ZipFile(download_path, "r") as zip_ref:
+        extracted_dir = os.path.dirname(dst_path)
+        zip_ref.extractall(extracted_dir)
+    os.rename(os.path.join(extracted_dir, zip_ref.infolist()[0].filename), dst_path)
+    return dst_path
diff --git a/demo_helpers/demo_helpers/models.py b/demo_helpers/demo_helpers/models.py
index 7d3a6f5..60d03fc 100644
--- a/demo_helpers/demo_helpers/models.py
+++ b/demo_helpers/demo_helpers/models.py
@@ -1,8 +1,18 @@
 import os
+import subprocess
+import sys
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from demo_helpers.model_download import (
+    YOLOV6N_MODEL,
+    YOLOV6N_SOURCE,
+    download_model,
+    download_source,
+)
+

 class M5(nn.Module):
     def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
@@ -130,6 +140,33 @@ def forward(self, input):
         return self.logsoftmax(output)


+def get_yolov6n_model():
+    weights = download_model(YOLOV6N_MODEL)
+    source = download_source(YOLOV6N_SOURCE)
+    export_script = os.path.join(source, "deploy/ONNX/export_onnx.py")
+
+    cmd = [
+        sys.executable,
+        export_script,
+        "--weights",
+        weights,
+        "--img",
+        "640",
+        "--batch",
+        "1",
+        "--simplify",
+    ]
+    p = subprocess.Popen(
+        cmd, cwd=source, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+    )
+    p.communicate()
+    if p.returncode != 0:
+        raise RuntimeError("Unable to get ONNX model")
+
+    onnx_file = weights.replace(".pt", ".onnx")
+    return onnx_file
+
+
 def load_pretrained(model_name):
     """Loads a pre-trained model
diff --git a/demo_helpers/setup.py b/demo_helpers/setup.py
index cb5aeb6..8815f05 100644
--- a/demo_helpers/setup.py
+++ b/demo_helpers/setup.py
@@ -10,16 +10,17 @@
     packages=find_packages(
         exclude=["*.__pycache__.*"],
     ),
+    include_package_data=True,
     install_requires=[
-        "charset-normalizer==2.1.0",
-        "torch>=1.12.0",
+        "charset-normalizer==3.3.2",
         "transformers>=4.20.0",
         "datasets>=2.3.2",
         "prettytable>=3.3.0",
         "wget>=3.2",
         "setuptools==57.2.0",
-        "torchvision>=0.11.3",
-        "torchaudio>=0.12.1",
+        "torchvision==0.16.0",
+        "torchaudio==2.1.0",
         "path>=16.4.0",
     ],
     classifiers=[],
diff --git a/docs/readme.md b/docs/readme.md
index 0c863c4..59432cd 100644
--- a/docs/readme.md
+++ b/docs/readme.md
@@ -1,4 +1,5 @@
 # Documentation
+
 The following are links to GroqFlow documentation:

 - [Install Guide](install.md): Instructions on how to install GroqFlow.
diff --git a/docs/known_issues.md b/docs/release_notes.md
similarity index 51%
rename from docs/known_issues.md
rename to docs/release_notes.md
index bba0e42..c5b29f7 100644
--- a/docs/known_issues.md
+++ b/docs/release_notes.md
@@ -1,7 +1,25 @@
-# GroqFlow Known Issues
+# Release Notes
+
+## v4.3.1
+
+### Changes
+
+* Support for SDK 0.11.
+* Add beta support for the groq-torch-importer front end.
+* Clean up package dependencies.
+* Various bug fixes.
+
+### Known Issues
+
+* The Yolo v6 proof point downloads the PyTorch weights and invokes the export script to produce the ONNX file.
+* Pip install of GroqFlow may complain about an incompatible protobuf version.
+
+## v4.2.1
+
+### Known Issues

 * Runtime errors due to mismatches in tensor sizes may occur even though GroqFlow checks the data shape. (G14148)
 * Whacky terminal line wrapping when printing groqit error messages. (G13235)
 * GroqFlow requires both the runtime and developer package to be installed. (G18283, G18284)
 * GroqFlow BERT Quantization Proof Point fails to compile in SDK0.9.3 due to a scheduling error. (G16739)
-* Yolo v6 Proof Points fails to run the evaluation after compilation in SDK0.9.2.1. (G18209)
\ No newline at end of file
+* Yolo v6 Proof Points fails to run the evaluation after compilation in SDK0.9.2.1. (G18209)
diff --git a/groqflow/common/build.py b/groqflow/common/build.py
index 1c5a0a4..3d1cea5 100644
--- a/groqflow/common/build.py
+++ b/groqflow/common/build.py
@@ -25,6 +25,7 @@
     "dont_use_sdk": "GROQFLOW_BAKE_SDK",
     "debug": "GROQFLOW_DEBUG",
     "internal": "GROQFLOW_INTERNAL_FEATURES",
+    "torch_importer": "GROQFLOW_USE_TORCH_IMPORTER",
 }

 # Allow an environment variable to override the default
@@ -63,6 +64,15 @@
 # the topology argument to groqit().
 TOPOLOGY = DRAGONFLY

+# Allow users to use the Torch Importer and bypass ONNX. Only applicable to
+# Torch models; it has no effect on other model types.
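+# For example, running `export GROQFLOW_USE_TORCH_IMPORTER=1` in the shell
+# before a script that calls groqit() on a torch.nn.Module enables this path.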
+if os.environ.get(environment_variables["torch_importer"]):
+    USE_TORCH_IMPORTER = True
+else:
+    USE_TORCH_IMPORTER = False
+

 class Backend(enum.Enum):
     AUTO = "auto"
@@ -170,6 +180,9 @@ class GroqInfo(of_build.Info):
     num_parameters: Optional[int] = None
     opt_onnx_unsupported_ops: Optional[List[str]] = None
     opt_onnx_all_ops_supported: Optional[bool] = None
+    torch_script_exported: Optional[bool] = None
+    torch_importer_success: Optional[bool] = None
+    torch_importer_command: Optional[str] = None
     compiler_success: Optional[bool] = None
     compiler_command: Optional[str] = None
     assembler_success: Optional[bool] = None
@@ -181,8 +194,8 @@ class GroqInfo(of_build.Info):
     estimated_pcie_output_latency: Optional[float] = None
     estimated_throughput: Optional[float] = None
     estimated_latency: Optional[float] = None
-    compiled_onnx_input_bytes: Optional[int] = None
-    compiled_onnx_output_bytes: Optional[int] = None
+    compiled_model_input_bytes: Optional[int] = None
+    compiled_model_output_bytes: Optional[int] = None
     compiler_ram_bytes: Optional[float] = None


@@ -233,6 +246,19 @@ def latency_file(self):
             of_build.output_dir(self.cache_dir, self.config.build_name), "latency.npy"
         )

+    @property
+    def torch_script_dir(self):
+        return os.path.join(
+            of_build.output_dir(self.cache_dir, self.config.build_name), "torchscript"
+        )
+
+    @property
+    def torch_script_file(self):
+        return os.path.join(
+            self.torch_script_dir,
+            f"{self.config.build_name}.pt",
+        )
+
     @property
     def compile_dir(self):
         return os.path.join(
diff --git a/groqflow/common/sdk_helpers.py b/groqflow/common/sdk_helpers.py
index f6c3d28..42fc15b 100644
--- a/groqflow/common/sdk_helpers.py
+++ b/groqflow/common/sdk_helpers.py
@@ -33,7 +33,11 @@ def get_num_chips_available(pci_devices=None):

     # Capture the list of pci devices on the system using the linux lspci utility
     if pci_devices is None:
-        pci_devices = subprocess.check_output([lspci, "-n"]).decode("utf-8").split("\n")
+        pci_devices = (
+            subprocess.check_output([lspci, "-n"], stderr=subprocess.DEVNULL)
+            .decode("utf-8")
+            .split("\n")
+        )

     # Unique registered vendor id: 1de0, and device id: "0000"
     groq_card_id = "1de0:0000"
@@ -74,7 +78,11 @@ def _installed_package_version(package: str, os_version: OS) -> Union[bool, str]

         # Get package info
         try:
             cmd = ["apt-cache", "policy", package]
-            package_info = subprocess.check_output(cmd).decode("utf-8").split("\n")
+            package_info = (
+                subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
+                .decode("utf-8")
+                .split("\n")
+            )
         except (FileNotFoundError, subprocess.CalledProcessError) as e:
             raise exp.Error("apt-cache policy command failed") from e
@@ -89,7 +97,11 @@ def _installed_package_version(package: str, os_version: OS) -> Union[bool, str]

         # Get package info
         cmd = ["dnf", "info", package]
         try:
-            package_info = subprocess.check_output(cmd).decode("utf-8").split("\n")
+            package_info = (
+                subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
+                .decode("utf-8")
+                .split("\n")
+            )
         except FileNotFoundError as e:
             raise exp.Error("dnf info command failed") from e
         except subprocess.CalledProcessError as e:
diff --git a/groqflow/groqmodel/groqmodel.py b/groqflow/groqmodel/groqmodel.py
index 4c50c9a..9e07c49 100644
--- a/groqflow/groqmodel/groqmodel.py
+++ b/groqflow/groqmodel/groqmodel.py
@@ -139,13 +139,13 @@ def estimate_performance(self) -> GroqEstimatedPerformance:

         # Calculate compute latency and estimate PCIe latency
         self.state.info.estimated_pcie_input_latency = (
-            self.state.info.compiled_onnx_input_bytes / pcie_bandwidth
+            self.state.info.compiled_model_input_bytes / pcie_bandwidth
         ) + pcie_latency
         self.state.info.deterministic_compute_latency = on_chip_compute_cycles / (
             frequency
         )
         self.state.info.estimated_pcie_output_latency = (
-            self.state.info.compiled_onnx_output_bytes / pcie_bandwidth
+            self.state.info.compiled_model_output_bytes / pcie_bandwidth
         ) + pcie_latency

         # When pipelined, the reported cycle is the duration of a single pipelining stage
diff --git a/groqflow/justgroqit/compile.py b/groqflow/justgroqit/compile.py
index e619d43..774aebe 100644
--- a/groqflow/justgroqit/compile.py
+++ b/groqflow/justgroqit/compile.py
@@ -3,6 +3,7 @@
 import subprocess
 import pathlib
 import onnx
+import torch
 import onnxflow.justbuildit.stage as stage
 import onnxflow.common.exceptions as exp
 import onnxflow.common.printing as printing
@@ -12,20 +13,7 @@
 import groqflow.common.sdk_helpers as sdk


-def get_and_analyze_onnx(state: build.GroqState):
-    # TODO: validate this input
-    # https://git.groq.io/code/Groq/-/issues/13947
-    input_onnx = state.intermediate_results[0]
-
-    (
-        state.info.compiled_onnx_input_bytes,
-        state.info.compiled_onnx_output_bytes,
-    ) = onnx_helpers.io_bytes(input_onnx)
-
-    # Count the number of trained model parameters
-    onnx_model = onnx.load(input_onnx)
-    state.info.num_parameters = int(onnx_helpers.parameter_count(onnx_model))
-
+def analyze_parameters(state: build.GroqState):
     # Automatically define the number of chips if num_chips is not provided
     if state.config.num_chips is None:
         state.num_chips_used = build.calculate_num_chips(state.info.num_parameters)
@@ -45,21 +33,69 @@
         """
         raise exp.StageError(msg)

-    return input_onnx

+def analyze_onnx(state: build.GroqState):
+    # TODO: validate this input
+    # https://git.groq.io/code/Groq/-/issues/13947
+    input_onnx = state.intermediate_results[0]

-class CompileOnnx(stage.Stage):
-    """
-    Stage that takes an ONNX file and compiles it into one or more
-    Alan Assembly (.aa) files.
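+    # These I/O byte totals feed the PCIe latency estimates in groqmodel.py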
+    (
+        state.info.compiled_model_input_bytes,
+        state.info.compiled_model_output_bytes,
+    ) = onnx_helpers.io_bytes(input_onnx)

-    Expected inputs:
-    - state.intermediate_results contains a single .onnx file
+    # Count the number of trained model parameters
+    onnx_model = onnx.load(input_onnx)
+    state.info.num_parameters = int(onnx_helpers.parameter_count(onnx_model))

-    Outputs:
-    - One or more .aa files
-    - state.num_chips_used contains the number of chips used by
-      Groq Compiler
+
+def analyze_torch_script(state: build.GroqState):
+    model = torch.jit.load(state.torch_script_file)
+    state.info.compiled_model_input_bytes = sum(
+        t.nelement() * t.element_size() for t in state.inputs.values()
+    )
+    outputs = model(**state.inputs)
+    state.info.compiled_model_output_bytes = sum(t.nelement() * t.element_size() for t in outputs)
+
+    state.info.num_parameters = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+
+
+def torch_types_to_str(type: torch.dtype):
+    if type == torch.float16:
+        return "f16"
+    elif type == torch.float32:
+        return "f32"
+    elif type == torch.float64:
+        return "f64"
+    elif type == torch.uint8:
+        return "ui8"
+    elif type == torch.bool:
+        return "i1"
+    elif type == torch.int8:
+        return "i8"
+    elif type == torch.int16:
+        return "i16"
+    elif type == torch.int32:
+        return "i32"
+    elif type == torch.int64:
+        return "i64"
+    elif type == torch.chalf:
+        return "complex"
+    elif type == torch.cfloat:
+        return "complex"
+    elif type == torch.cdouble:
+        return "complex"
+    else:
+        raise TypeError("Unsupported Torch type", type)
+
+
+class Compile(stage.Stage):
+    """
+    Base class for the Compile stages; the derived class supplies the
+    input file through state.intermediate_results.
     """

     def __init__(self):
@@ -72,7 +108,9 @@

         sdk.check_dependencies(require_devtools=True, exception_type=exp.StageError)

-        input_onnx = get_and_analyze_onnx(state)
+        analyze_parameters(state)
+
+        input_file = state.intermediate_results[0]

         # Select either bake or SDK
         if state.use_sdk:
@@ -103,7 +141,7 @@
         # Add flags
         cmd = (
             cmd
-            + [input_onnx]
+            + [input_file]
             + state.config.compiler_flags
             + [
                 "--save-stats",
@@ -186,6 +224,93 @@
         return state


+class CompileOnnx(Compile):
+    """
+    Stage that takes an ONNX file and compiles it into one or more
+    Alan Assembly (.aa) files.
+
+    Expected inputs:
+    - state.intermediate_results contains a single .onnx file
+
+    Outputs:
+    - One or more .aa files
+    - state.num_chips_used contains the number of chips used by
+      Groq Compiler
+    """
+
+    def fire(self, state: build.GroqState):
+        analyze_onnx(state)
+        return super().fire(state)
+
+
+class CompileTorchScript(Compile):
+    """
+    Stage that takes a TorchScript file and compiles it into GTen.
+
+    Expected inputs:
+    - state.intermediate_results contains a single .pt file
+
+    Outputs:
+    - One .mlir file
+    - state.expected_output_names will contain the output names of the model.
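+    - state.intermediate_results is updated to point at the generated GTen file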
+ """ + + def fire(self, state: build.GroqState): + analyze_torch_script(state) + + # Select either bake or SDK + if state.use_sdk: + sdk.check_dependencies(require_devtools=True, exception_type=exp.StageError) + cmd = sdk.find_tool("groq-torch-importer") + else: + cmd = ["bake", "r", "//Groq/Compiler/Import/Torch:groq-torch-importer"] + + input_types = [] + for data in state.inputs.values(): + shape = "x".join([str(dim) for dim in data.shape]) + dtype = torch_types_to_str(data.dtype) + input_types.append(f"--input-types={shape}x{dtype}") + + gten_file = os.path.join( + state.compile_dir, + f"{state.config.build_name}.gten.mlir", + ) + + cmd = cmd + [state.torch_script_file] + input_types + ["-o", gten_file] + + # Remove duplicated flags + cmd = sorted(set(cmd), key=cmd.index) + state.info.torch_importer_command = " ".join(cmd) + + printing.logn("Running Groq Torch Importer...") + + with subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) as process: + for line in process.stdout: + printing.logn(line.decode("utf8"), end="") + printing.logn("Groq Torch Importer has exited") + + state.info.torch_importer_success = ( + True if os.path.exists(gten_file) and os.path.isfile(gten_file) else False + ) + + if state.info.torch_importer_success: + state.intermediate_results = [gten_file] + else: + msg = f""" + Attempted to use Groq Torch Importer to import TorchSript model into + Groq's Tensor(GTen) dialect format. However, this operation did not + succeed. Please contact GroqFlow support to determine a path forwards. + More information may be available in the log file at **{self.logfile_path}** + """ + raise exp.StageError(msg) + + # Compile the GTen file + return super().fire(state) + + class Assemble(stage.Stage): """ Stage that takes a list of Alan Assembly (.aa) files and runs Groq diff --git a/groqflow/justgroqit/export.py b/groqflow/justgroqit/export.py index d7e850d..50b46a8 100644 --- a/groqflow/justgroqit/export.py +++ b/groqflow/justgroqit/export.py @@ -1,10 +1,22 @@ +import inspect +import os +import sys +import warnings +import torch import onnxflow.justbuildit.stage as stage import onnxflow.common.exceptions as exp +import onnxflow.common.tensor_helpers as tensor_helpers import groqflow.common.build as build import groqflow.common.onnx_helpers as onnx_helpers import groqflow.common.sdk_helpers as sdk +def _warn_to_stdout(message, category, filename, line_number, _, line): + sys.stdout.write( + warnings.formatwarning(message, category, filename, line_number, line) + ) + + class CheckOnnxCompatibility(stage.Stage): """ Stage that takes an ONNX file, checks whether it is compatible @@ -54,3 +66,104 @@ def fire(self, state: build.GroqState): raise exp.StageError(msg) return state + + +class ExportPytorchToTorchScript(stage.Stage): + """ + Stage that takes a Pytorch module and exports it to TorchScript using + torch.jit API. 
+
+    Expected inputs:
+    - state.model is a torch.nn.Module or torch.jit.ScriptModule
+    - state.inputs is a dict that represents valid kwargs to the forward
+      function of state.model
+
+    Outputs:
+    - A *.pt file that implements state.model given state.inputs
+    """
+
+    def __init__(self):
+        super().__init__(
+            unique_name="export_pytorch_to_torch_script",
+            monitor_message="Exporting PyTorch to TorchScript",
+        )
+
+    @staticmethod
+    def _check_model(torch_script_file, success_message, fail_message) -> bool:
+        if os.path.isfile(torch_script_file):
+            print(success_message)
+            return True
+        else:
+            print(fail_message)
+            return False
+
+    def fire(self, state: build.GroqState):
+        if not isinstance(state.model, (torch.nn.Module, torch.jit.ScriptModule)):
+            msg = f"""
+            The current stage (ExportPytorchToTorchScript) is only compatible
+            with models of type torch.nn.Module or torch.jit.ScriptModule,
+            however the stage received a model of type {type(state.model)}.
+            """
+            raise exp.StageError(msg)
+
+        if isinstance(state.model, torch.nn.Module):
+            # Validate user-provided args
+            all_args = list(inspect.signature(state.model.forward).parameters.keys())
+
+            for inp in list(state.inputs.keys()):
+                if inp not in all_args:
+                    msg = f"""
+                    Input name {inp} not found in the model's forward method. Available
+                    input names are: {all_args}
+                    """
+                    raise ValueError(msg)
+
+        # Send torch export warnings to stdout (and therefore the log file)
+        # so that they don't fill up the command line
+        default_warnings = warnings.showwarning
+        warnings.showwarning = _warn_to_stdout
+
+        # Export the model to TorchScript
+        jit_module = torch.jit.trace(
+            state.model,
+            example_kwarg_inputs=state.inputs,
+        )
+
+        # Save model to disk
+        os.makedirs(state.torch_script_dir, exist_ok=True)
+        jit_module.save(state.torch_script_file)
+
+        # Save output names to ensure we are preserving the order of the outputs.
+        # We have to re-load the TorchScript module because the output names
+        # will change during serialization.
+        loaded_jit_module = torch.jit.load(state.torch_script_file)
+        state.expected_output_names = [
+            output.debugName() for output in loaded_jit_module.graph.outputs()
+        ]
+
+        # Restore default warnings behavior
+        warnings.showwarning = default_warnings
+
+        tensor_helpers.save_inputs(
+            [state.inputs], state.original_inputs_file, downcast=False
+        )
+
+        # Check if the model has been exported successfully
+        success_msg = "\tSuccess exporting model to TorchScript"
+        fail_msg = "\tFailed exporting model to TorchScript"
+        state.info.torch_script_exported = self._check_model(
+            state.torch_script_file, success_msg, fail_msg
+        )
+
+        if state.info.torch_script_exported:
+            state.intermediate_results = [state.torch_script_file]
+        else:
+            msg = f"""
+            Unable to export model to TorchScript using Torch's jit exporter.
+            We recommend that you modify your model until it is
+            compatible with this third party software, then re-run.
+            More information may be available in the log file at **{self.logfile_path}**
+            """
+            raise exp.StageError(msg)
+
+        return state
diff --git a/groqflow/justgroqit/groqit.py b/groqflow/justgroqit/groqit.py
index bb221fe..302deee 100755
--- a/groqflow/justgroqit/groqit.py
+++ b/groqflow/justgroqit/groqit.py
@@ -12,7 +12,7 @@ def groqit(
     model: of_build.UnionValidModelInstanceTypes = None,
     inputs: Optional[Dict[str, Any]] = None,
     build_name: Optional[str] = None,
-    cache_dir: str = build.DEFAULT_CACHE_DIR,
+    cache_dir: Optional[str] = build.DEFAULT_CACHE_DIR,
     monitor: bool = True,
     rebuild: Optional[str] = None,
     compiler_flags: Optional[List[str]] = None,
diff --git a/groqflow/justgroqit/ignition.py b/groqflow/justgroqit/ignition.py
index 2b1ba5a..afc4f1f 100644
--- a/groqflow/justgroqit/ignition.py
+++ b/groqflow/justgroqit/ignition.py
@@ -58,6 +58,25 @@
     ],
 )

+
+pytorch_export_torch_script_sequence = stage.Sequence(
+    "pytorch_export_torch_script",
+    "Export Pytorch Model to TorchScript",
+    [
+        gf_export.ExportPytorchToTorchScript(),
+    ],
+)
+
+pytorch_torch_importer_sequence = stage.Sequence(
+    "pytorch_torch_importer",
+    "Build PyTorch Model using Torch Importer Front-end",
+    [
+        pytorch_export_torch_script_sequence,
+        compile.CompileTorchScript(),
+        compile.Assemble(),
+    ],
+)
+
 default_keras_export_sequence = stage.Sequence(
     "default_keras_export_sequence",
     "Exporting Keras Model",
@@ -337,6 +356,13 @@ def load_or_make_state(
     of_build.ModelType.PYTORCH: pytorch_sequence_with_quantization,
 }

+groq_model_type_torch_importer_override_to_sequence = {
+    of_build.ModelType.PYTORCH: pytorch_torch_importer_sequence,
+    of_build.ModelType.KERAS: default_keras_sequence,
+    of_build.ModelType.ONNX_FILE: default_onnx_sequence,
+    of_build.ModelType.HUMMINGBIRD: default_hummingbird_sequence,
+}
+

 def model_intake(
     user_model,
@@ -346,13 +372,17 @@ def model_intake(
     user_quantization_samples: Optional[Collection] = None,
 ) -> Tuple[Any, Any, stage.Sequence, of_build.ModelType]:

+    override_sequence_map = groq_model_type_to_sequence
+    if build.USE_TORCH_IMPORTER:
+        override_sequence_map = groq_model_type_torch_importer_override_to_sequence
+
     model, inputs, sequence, model_type = of_ignition.model_intake(
         user_model=user_model,
         user_inputs=user_inputs,
         user_sequence=user_sequence,
         user_quantization_samples=user_quantization_samples,
         override_quantization_sequence_map=groq_model_type_to_sequence_with_quantization,
-        override_sequence_map=groq_model_type_to_sequence,
+        override_sequence_map=override_sequence_map,
     )

     if "--auto-asm" in config.compiler_flags:
diff --git a/groqflow/version.py b/groqflow/version.py
index aef46ac..ed48cda 100644
--- a/groqflow/version.py
+++ b/groqflow/version.py
@@ -1 +1 @@
-__version__ = "4.2.1"
+__version__ = "4.3.1"
diff --git a/license.md b/license.md
index 2ecdd6b..27baae1 100644
--- a/license.md
+++ b/license.md
@@ -4,4 +4,4 @@ Permission is hereby granted, free of charge, to any person obtaining a copy of

 The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/proof_points/README.md b/proof_points/README.md
index a79f6d8..b7f56dc 100644
--- a/proof_points/README.md
+++ b/proof_points/README.md
@@ -44,7 +44,7 @@ The following relates the proof point models with the version of the GroqWare Su
 | [MobileNetV2](computer_vision/mobilenetv2/) | >=0.9.2.1
 | [ResNet50](computer_vision/resnet50/) | >=0.9.2.1
 | [SqueezeNet](computer_vision/squeezenet/) | >=0.9.2.1
-| [Yolo v6](computer_vision/yolo/) | 0.9.3
+| [Yolo v6](computer_vision/yolo/) | >=0.11.0

 ### Natural Language Processing
diff --git a/proof_points/computer_vision/deit/deit_tiny.py b/proof_points/computer_vision/deit/deit_tiny.py
index b8adee8..d274a51 100644
--- a/proof_points/computer_vision/deit/deit_tiny.py
+++ b/proof_points/computer_vision/deit/deit_tiny.py
@@ -15,7 +15,9 @@

 def evaluate_deit_tiny(rebuild_policy=None, should_execute=True):
     # load torch model
-    model = ViTForImageClassification.from_pretrained("facebook/deit-tiny-patch16-224")
+    model = ViTForImageClassification.from_pretrained(
+        "facebook/deit-tiny-patch16-224", torchscript=True
+    )
     model.eval()

     # create dummy inputs to prime groq model
@@ -26,12 +28,13 @@

     # compute performance on CPU and GroqChip
     if should_execute:
-        return compute_performance(
+        compute_performance(
             groq_model,
             model,
             dataset="sampled_imagenet",
             task="classification",
         )
+    print(f"Proof point {__file__} finished!")


 if __name__ == "__main__":
diff --git a/proof_points/computer_vision/deit/requirements.txt b/proof_points/computer_vision/deit/requirements.txt
index 4b8d1f4..4e5dbc3 100644
--- a/proof_points/computer_vision/deit/requirements.txt
+++ b/proof_points/computer_vision/deit/requirements.txt
@@ -1,2 +1,2 @@
 torch>=1.12.0
-transformers>=4.20.0
\ No newline at end of file
+transformers>=4.20.0
diff --git a/proof_points/computer_vision/googlenet/googlenet.py b/proof_points/computer_vision/googlenet/googlenet.py
index 2cf3800..b10c57b 100644
--- a/proof_points/computer_vision/googlenet/googlenet.py
+++ b/proof_points/computer_vision/googlenet/googlenet.py
@@ -40,6 +40,8 @@ def evaluate_googlenet(rebuild_policy=None, should_execute=None):
             groq_model, torch_model, "sampled_imagenet", task="classification"
         )

+    print(f"Proof point {__file__} finished!")
+

 if __name__ == "__main__":
     evaluate_googlenet(**parse_args())
diff --git a/proof_points/computer_vision/mobilenetv2/mobilenetv2.py b/proof_points/computer_vision/mobilenetv2/mobilenetv2.py
index 8cbf42d..8b2153f 100644
--- a/proof_points/computer_vision/mobilenetv2/mobilenetv2.py
+++ b/proof_points/computer_vision/mobilenetv2/mobilenetv2.py
@@ -43,6 +43,8 @@ def evaluate_mobilenetv2(rebuild_policy=None, should_execute=None):
             groq_model, torch_model, "sampled_imagenet", task="classification"
         )
+ print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_mobilenetv2(**parse_args()) diff --git a/proof_points/computer_vision/mobilenetv2/requirements.txt b/proof_points/computer_vision/mobilenetv2/requirements.txt index 2ec1224..be222b0 100644 --- a/proof_points/computer_vision/mobilenetv2/requirements.txt +++ b/proof_points/computer_vision/mobilenetv2/requirements.txt @@ -1 +1 @@ -torch>=1.12.0 \ No newline at end of file +torch>=1.12.0 diff --git a/proof_points/computer_vision/resnet50/requirements.txt b/proof_points/computer_vision/resnet50/requirements.txt index 2ec1224..be222b0 100644 --- a/proof_points/computer_vision/resnet50/requirements.txt +++ b/proof_points/computer_vision/resnet50/requirements.txt @@ -1 +1 @@ -torch>=1.12.0 \ No newline at end of file +torch>=1.12.0 diff --git a/proof_points/computer_vision/resnet50/resnet50.py b/proof_points/computer_vision/resnet50/resnet50.py index 7e55db3..8605bbc 100644 --- a/proof_points/computer_vision/resnet50/resnet50.py +++ b/proof_points/computer_vision/resnet50/resnet50.py @@ -27,10 +27,12 @@ def evaluate_resnet50(rebuild_policy=None, should_execute=True): # Execute PyTorch model on CPU, Groq Model and print accuracy if should_execute: - return compute_performance( + compute_performance( groq_model, pytorch_model, "sampled_imagenet", task="classification" ) + print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_resnet50(**parse_args()) diff --git a/proof_points/computer_vision/squeezenet/squeezenet.py b/proof_points/computer_vision/squeezenet/squeezenet.py index 387a328..1a06abc 100644 --- a/proof_points/computer_vision/squeezenet/squeezenet.py +++ b/proof_points/computer_vision/squeezenet/squeezenet.py @@ -40,6 +40,8 @@ def evaluate_squeezenet(rebuild_policy=None, should_execute=None): groq_model, torch_model, "sampled_imagenet", task="classification" ) + print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_squeezenet(**parse_args()) diff --git a/proof_points/computer_vision/yolo/yolov6_nano.py b/proof_points/computer_vision/yolo/yolov6_nano.py index 2656aab..172550c 100644 --- a/proof_points/computer_vision/yolo/yolov6_nano.py +++ b/proof_points/computer_vision/yolo/yolov6_nano.py @@ -4,33 +4,31 @@ the COCO dataset (https://cocodataset.org/) on CPU and GroqChipâ„¢ processor using the GroqFlow toolchain. 
""" +import torch from groqflow import groqit from demo_helpers.args import parse_args from demo_helpers.compute_performance import compute_performance -from demo_helpers.model_download import YOLOV6N_ONNX, download_model +from demo_helpers.models import get_yolov6n_model from demo_helpers.misc import check_deps -import torch - - -def get_onnx_model(): - return download_model(YOLOV6N_ONNX) - def evaluate_yolov6n(rebuild_policy=None, should_execute=True): check_deps(__file__) - pytorch_model = get_onnx_model() - dummy_inputs = {"image_arrays": torch.ones([1, 3, 640, 640])} + model = get_yolov6n_model() + dummy_inputs = {"images": torch.ones([1, 3, 640, 640])} # Get Groq Model using groqit groq_model = groqit( - pytorch_model, + model, dummy_inputs, rebuild=rebuild_policy, + compiler_flags=["--effort=high"], ) if should_execute: - return compute_performance(groq_model, pytorch_model, "coco", task="coco_map") + compute_performance(groq_model, model, "coco", task="coco_map") + + print(f"Proof point {__file__} finished!") if __name__ == "__main__": diff --git a/proof_points/natural_language_processing/bert/bert_base.py b/proof_points/natural_language_processing/bert/bert_base.py index c37f020..105e7fd 100644 --- a/proof_points/natural_language_processing/bert/bert_base.py +++ b/proof_points/natural_language_processing/bert/bert_base.py @@ -20,7 +20,7 @@ def get_model(): tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name) pytorch_model = transformers.AutoModelForSequenceClassification.from_pretrained( - pretrained_model_name + pretrained_model_name, torchscript=True ) return pytorch_model.eval(), tokenizer @@ -48,7 +48,7 @@ def evaluate_bert(rebuild_policy=None, should_execute=True): # compute performance on CPU and GroqChip if should_execute: - return compute_performance( + compute_performance( groq_model, pytorch_model, dataset="sst", @@ -57,6 +57,8 @@ def evaluate_bert(rebuild_policy=None, should_execute=True): task="classification", ) + print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_bert(**parse_args()) diff --git a/proof_points/natural_language_processing/bert/bert_quantize.py b/proof_points/natural_language_processing/bert/bert_quantize.py index e4fae6a..f247b66 100644 --- a/proof_points/natural_language_processing/bert/bert_quantize.py +++ b/proof_points/natural_language_processing/bert/bert_quantize.py @@ -63,8 +63,7 @@ def evaluate_bert(rebuild_policy=None, should_execute=True): ) if should_execute: - - return compute_performance( + compute_performance( groq_model, pytorch_model, dataset="sst-int32", @@ -73,6 +72,8 @@ def evaluate_bert(rebuild_policy=None, should_execute=True): task="classification", ) + print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_bert(**parse_args()) diff --git a/proof_points/natural_language_processing/bert/bert_tiny.py b/proof_points/natural_language_processing/bert/bert_tiny.py index c7be470..5ee5615 100644 --- a/proof_points/natural_language_processing/bert/bert_tiny.py +++ b/proof_points/natural_language_processing/bert/bert_tiny.py @@ -20,7 +20,7 @@ def get_model(): tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name) pytorch_model = transformers.AutoModelForSequenceClassification.from_pretrained( - pretrained_model_name + pretrained_model_name, torchscript=True ) return pytorch_model.eval(), tokenizer @@ -48,7 +48,7 @@ def evaluate_bert_tiny(rebuild_policy=None, should_execute=True): # compute performance on CPU and GroqChip if 
should_execute: - return compute_performance( + compute_performance( groq_model, pytorch_model, dataset="sst", @@ -57,6 +57,8 @@ def evaluate_bert_tiny(rebuild_policy=None, should_execute=True): task="classification", ) + print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_bert_tiny(**parse_args()) diff --git a/proof_points/natural_language_processing/bert/requirements.txt b/proof_points/natural_language_processing/bert/requirements.txt index 5babbda..a568b47 100644 --- a/proof_points/natural_language_processing/bert/requirements.txt +++ b/proof_points/natural_language_processing/bert/requirements.txt @@ -1,3 +1,3 @@ -numpy>=1.22.4 +numpy>=1.21.6 torch>=1.12.1 transformers>=4.20.0 diff --git a/proof_points/natural_language_processing/distilbert/distilbert.py b/proof_points/natural_language_processing/distilbert/distilbert.py index 48be522..c2dc94e 100644 --- a/proof_points/natural_language_processing/distilbert/distilbert.py +++ b/proof_points/natural_language_processing/distilbert/distilbert.py @@ -18,7 +18,7 @@ def evaluate_distilbert(rebuild_policy=None, should_execute=True): pretrained_model = "distilbert-base-uncased-finetuned-sst-2-english" tokenizer = AutoTokenizer.from_pretrained(pretrained_model) pytorch_model = DistilBertForSequenceClassification.from_pretrained( - pretrained_model + pretrained_model, torchscript=True ) # dummy inputs to generate the groq model @@ -52,6 +52,8 @@ def evaluate_distilbert(rebuild_policy=None, should_execute=True): task="classification", ) + print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_distilbert(**parse_args()) diff --git a/proof_points/natural_language_processing/distilbert/requirements.txt b/proof_points/natural_language_processing/distilbert/requirements.txt index 4b8d1f4..4e5dbc3 100644 --- a/proof_points/natural_language_processing/distilbert/requirements.txt +++ b/proof_points/natural_language_processing/distilbert/requirements.txt @@ -1,2 +1,2 @@ torch>=1.12.0 -transformers>=4.20.0 \ No newline at end of file +transformers>=4.20.0 diff --git a/proof_points/natural_language_processing/electra/electra.py b/proof_points/natural_language_processing/electra/electra.py index 0694086..68ff88c 100644 --- a/proof_points/natural_language_processing/electra/electra.py +++ b/proof_points/natural_language_processing/electra/electra.py @@ -24,7 +24,7 @@ def evaluate_electra(rebuild_policy=None, should_execute=True): tokenizer = transformers.ElectraTokenizerFast.from_pretrained(pretrained_model_name) pytorch_model = transformers.ElectraForSequenceClassification.from_pretrained( - pretrained_model_name + pretrained_model_name, torchscript=True ) pytorch_model.eval() @@ -41,7 +41,7 @@ def evaluate_electra(rebuild_policy=None, should_execute=True): # compute performance on CPU and GroqChip if should_execute: - return compute_performance( + compute_performance( groq_model, pytorch_model, dataset="sst", @@ -50,6 +50,8 @@ def evaluate_electra(rebuild_policy=None, should_execute=True): task="classification", ) + print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_electra(**parse_args()) diff --git a/proof_points/natural_language_processing/electra/requirements.txt b/proof_points/natural_language_processing/electra/requirements.txt index 4b8d1f4..4e5dbc3 100644 --- a/proof_points/natural_language_processing/electra/requirements.txt +++ b/proof_points/natural_language_processing/electra/requirements.txt @@ -1,2 +1,2 @@ torch>=1.12.0 -transformers>=4.20.0 \ No newline at end 
of file +transformers>=4.20.0 diff --git a/proof_points/natural_language_processing/minilm/minilmv2.py b/proof_points/natural_language_processing/minilm/minilmv2.py index 0241016..bee136f 100644 --- a/proof_points/natural_language_processing/minilm/minilmv2.py +++ b/proof_points/natural_language_processing/minilm/minilmv2.py @@ -34,7 +34,7 @@ def evaluate_minilm(rebuild_policy=None, should_execute=True): # compute performance on CPU and GroqChip if should_execute: - return compute_performance( + compute_performance( groq_model, model, dataset="stsb_multi_mt", @@ -43,6 +43,8 @@ def evaluate_minilm(rebuild_policy=None, should_execute=True): task="sentence_similarity", ) + print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_minilm(**parse_args()) diff --git a/proof_points/natural_language_processing/minilm/requirements.txt b/proof_points/natural_language_processing/minilm/requirements.txt index 4b8d1f4..4e5dbc3 100644 --- a/proof_points/natural_language_processing/minilm/requirements.txt +++ b/proof_points/natural_language_processing/minilm/requirements.txt @@ -1,2 +1,2 @@ torch>=1.12.0 -transformers>=4.20.0 \ No newline at end of file +transformers>=4.20.0 diff --git a/proof_points/natural_language_processing/roberta/requirements.txt b/proof_points/natural_language_processing/roberta/requirements.txt index 4b8d1f4..4e5dbc3 100644 --- a/proof_points/natural_language_processing/roberta/requirements.txt +++ b/proof_points/natural_language_processing/roberta/requirements.txt @@ -1,2 +1,2 @@ torch>=1.12.0 -transformers>=4.20.0 \ No newline at end of file +transformers>=4.20.0 diff --git a/proof_points/natural_language_processing/roberta/roberta.py b/proof_points/natural_language_processing/roberta/roberta.py index 271beba..4396715 100644 --- a/proof_points/natural_language_processing/roberta/roberta.py +++ b/proof_points/natural_language_processing/roberta/roberta.py @@ -23,7 +23,9 @@ def evaluate_roberta(rebuild_policy=None, should_execute=None): # load pre-trained torch model model_path = "dominiqueblok/roberta-base-finetuned-ner" tokenizer = RobertaTokenizerFast.from_pretrained(model_path) - torch_model = RobertaForTokenClassification.from_pretrained(model_path) + torch_model = RobertaForTokenClassification.from_pretrained( + model_path, torchscript=True + ) # dummy inputs to generate the groq model batch_size, max_seq_length = 1, 128 @@ -53,6 +55,8 @@ def evaluate_roberta(rebuild_policy=None, should_execute=None): task="ner", ) + print(f"Proof point {__file__} finished!") + if __name__ == "__main__": evaluate_roberta(**parse_args()) diff --git a/proof_points/speech/m5/README.md b/proof_points/speech/m5/README.md index 9dd02d9..3940a4d 100644 --- a/proof_points/speech/m5/README.md +++ b/proof_points/speech/m5/README.md @@ -19,6 +19,19 @@ M5's Keyword Spotting accuracy is evaluated using the [SpeechCommands dataset](h pip install -r requirements.txt ``` +- Since this proofpoint uses audio files, often the audio libraries must be installed on system. 
+  - For Ubuntu OS:
+
+    ```bash
+    sudo apt install libsox-dev
+    ```
+
+  - For Rocky OS:
+
+    ```bash
+    sudo dnf install sox-devel
+    ```
+
 ## Build and Evaluate

 To build and evaluate M5:
diff --git a/proof_points/speech/m5/m5.py b/proof_points/speech/m5/m5.py
index 4dd0b60..96c71cd 100644
--- a/proof_points/speech/m5/m5.py
+++ b/proof_points/speech/m5/m5.py
@@ -37,6 +37,8 @@ def evaluate_m5(rebuild_policy=None, should_execute=True):
         task="keyword_spotting",
     )

+    print(f"Proof point {__file__} finished!")
+

 if __name__ == "__main__":
     evaluate_m5(**parse_args())
diff --git a/pyproject.toml b/pyproject.toml
index 8cf3256..638dd9c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,3 @@
 [build-system]
 requires = ["setuptools>=61.0"]
-build-backend = "setuptools.build_meta"
\ No newline at end of file
+build-backend = "setuptools.build_meta"
diff --git a/setup.py b/setup.py
index 645a278..3bbb6a9 100644
--- a/setup.py
+++ b/setup.py
@@ -15,21 +15,13 @@
         exclude=["*.__pycache__.*"],
     ),
     install_requires=[
-        "onnx>=1.11.0",
-        "onnxmltools==1.10.0",
-        "hummingbird-ml==0.4.4",
+        "mlagility==3.3.1",
+        "onnx==1.14.0",
+        "onnxruntime==1.15.1",
+        "protobuf==3.20.3",
         "scikit-learn==1.1.1",
-        "xgboost==1.6.1",
-        "onnxruntime>=1.11.0",
-        "paramiko==2.11.0",
-        "torch>=1.12.1",
-        "protobuf>=3.17.3",
-        "pyyaml>=5.4",
+        "torch==2.1.0",
         "typeguard==4.0.0",
-        "typing_extensions==4.5.0",
-        "protobuf==3.20.3",
-        "packaging>=21.3",
-        "mlagility==3.2.0",
     ],
     extras_require={
         "tensorflow": ["tensorflow-cpu>=2.8.1", "tf2onnx>=1.12.0"],