Skip to content

Commit

Permalink
Merge pull request #61 from kerrlabajo/features
Browse files (browse the repository at this point in the history)
Minor revisions to the Docker training script to report a proper failure reason for the training job
  • Loading branch information
kerrlabajo authored May 15, 2024
2 parents 81b5a49 + 95f4dac commit 868cbeb
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 31 deletions.
22 changes: 17 additions & 5 deletions Form1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,11 @@ public MainForm(bool development)
UserConnectionInfo.Region = Environment.GetEnvironmentVariable("REGION");
UserConnectionInfo.RoleArn = Environment.GetEnvironmentVariable("ROLE_ARN");
UserConnectionInfo.EcrUri = Environment.GetEnvironmentVariable("INTELLISYS_ECR_URI");
UserConnectionInfo.SagemakerBucket = Environment.GetEnvironmentVariable("SAGEMAKER_BUCKET");
UserConnectionInfo.DefaultDatasetURI = Environment.GetEnvironmentVariable("DEFAULT_DATASET_URI");
UserConnectionInfo.CustomUploadsURI = Environment.GetEnvironmentVariable("CUSTOM_UPLOADS_URI");
UserConnectionInfo.DestinationURI = Environment.GetEnvironmentVariable("DESTINATION_URI");

UserConnectionInfo.SagemakerBucket = $"sagemaker-{UserConnectionInfo.Region}-{UserConnectionInfo.AccountId}";
UserConnectionInfo.DefaultDatasetURI = $"s3://{UserConnectionInfo.SagemakerBucket}/default-datasets/MMX059XA_COVERED5B/";
UserConnectionInfo.CustomUploadsURI = $"s3://{UserConnectionInfo.SagemakerBucket}/users/{UserConnectionInfo.UserName}/custom-uploads/";
UserConnectionInfo.DestinationURI = $"s3://{UserConnectionInfo.SagemakerBucket}/users/{UserConnectionInfo.UserName}/training-jobs/";
MessageBox.Show("Established Connection using ENV for Development", "Success", MessageBoxButtons.OK, MessageBoxIcon.Information);
}
else if (!development && UserConnectionInfo.AccountId == null && UserConnectionInfo.AccessKey == null && UserConnectionInfo.SecretKey == null && UserConnectionInfo.Region == null && UserConnectionInfo.RoleArn == null)
Expand Down Expand Up @@ -460,7 +461,8 @@ private void backgroundWorker_ProgressChanged(object sender, System.ComponentMod
/// <param name="e">An instance of RunWorkerCompletedEventArgs containing event data.</param>
private void backgroundWorker_RunWorkerCompleted(object sender, System.ComponentModel.RunWorkerCompletedEventArgs e)
{
MessageBox.Show("Upload completed!");
if (progressBar.Value >= 100)
MessageBox.Show("Upload completed!");
progressBar.Value = 0;
mainPanel.Enabled = true;
logPanel.Enabled = true;
Expand Down Expand Up @@ -981,6 +983,10 @@ private void CalculateBatchSize()
{
idealBatchSize = -1;
}
else if (!supportedInstances.Contains(instance))
{
idealBatchSize = 16 * instanceCount;
}
else
{
idealBatchSize = 16 * instanceCount * gpuCount;
Expand Down Expand Up @@ -1047,6 +1053,12 @@ private bool ValidateTrainingParameters(string img_size, string batch_size, stri
}
}

if (!supportedInstances.Contains(selectedInstance) && Int32.TryParse(instanceCount, out int instance) && instance > 1)
{
MessageBox.Show("Multi-instance training does not support instances with no GPU", "Validation Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
return false;
}

if (!Int32.TryParse(txtBatchSize.Text, out int batchSize))
{
MessageBox.Show("Batch size must be an integer.", "Validation Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
Expand Down
5 changes: 0 additions & 5 deletions Functions/AWS_Helper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ namespace LSC_Trainer.Functions
/// </summary>
public class AWS_Helper
{
/// <summary>
/// Represents the total size of data uploaded.
/// </summary>
private static long totalUploaded = 0;

/// <summary>
/// Validates the provided access key ID by retrieving the username associated with it using the IAM client.
/// Updates the UserConnectionInfo.UserName property with the retrieved username if the key is valid.
Expand Down
141 changes: 134 additions & 7 deletions Functions/FileTransferUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
using SharpCompress.Common;
using SharpCompress.Readers;
using System.Threading;
using System.Net;
using Amazon.Runtime;

namespace LSC_Trainer.Functions
{
Expand All @@ -22,6 +24,8 @@ internal class FileTransferUtility : IFileTransferUtility
private IUIUpdater UIUpdater { get; set; }

private CancellationTokenSource cancellationTokenSource = new CancellationTokenSource();

private bool isMessageBoxShown = false;
public FileTransferUtility(IUIUpdater uIUpdater)
{
UIUpdater = uIUpdater;
Expand Down Expand Up @@ -92,7 +96,7 @@ public async Task<string> UploadFileToS3(AmazonS3Client s3Client, string filePat
using (TransferUtility transferUtility = new TransferUtility(s3Client))
{
var uploadRequest = CreateUploadRequest(filePath, fileName, bucketName);
ConfigureProgressTracking(uploadRequest, progress, totalSize, UIUpdater,cancellationTokenSource.Token);
ConfigureProgressTracking(uploadRequest, progress, totalSize, UIUpdater, cancellationTokenSource.Token);

await transferUtility.UploadAsync(uploadRequest, cancellationTokenSource.Token);

Expand All @@ -101,14 +105,63 @@ public async Task<string> UploadFileToS3(AmazonS3Client s3Client, string filePat
UIUpdater.UpdateTrainingStatus($"Uploading Files to S3", $"Uploading {totalUploaded}/{totalSize} - {overallPercentage}%");
}
}


LogUploadTime(startTime);
return fileName;
}
catch (AmazonS3Exception e)
{
LogError("Error uploading file to S3: ", e);
if (e.ErrorCode == "RequestTimeTooSkewed")
{
if (!isMessageBoxShown)
{
isMessageBoxShown = true;
MessageBox.Show($"Error uploading file to S3: A file took too long to upload. The difference between the request time and the current time is too large.", "Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
isMessageBoxShown = false;
}
else
{
Console.WriteLine($"Error uploading file to S3: A file took too long to upload. The difference between the request time and the current time is too large.");
}
}
else
{
if (!isMessageBoxShown)
{
isMessageBoxShown = true;
MessageBox.Show($"Error uploading file to S3: {e}", "Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
isMessageBoxShown = false;
}
else
{
LogError("Error uploading file to S3: ", e);
}
}
cancellationTokenSource.Cancel();
return null;
}
catch (AmazonServiceException e)
{
if (e.InnerException is WebException webEx && webEx.Status == WebExceptionStatus.NameResolutionFailure)
{
// Handle the NameResolutionFailure exception
if (!isMessageBoxShown)
{
isMessageBoxShown = true;
MessageBox.Show($"Error in uploading file to S3: Failed to resolve the hostname. Please check your network connection and the hostname.", "Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
isMessageBoxShown = false;
}
else
{
Console.WriteLine($"Error in uploading file to S3: Failed to resolve the hostname. Please check your network connection and the hostname.");
}
}
else
{
LogError("Error uploading file to S3: An error occurred within the AWS SDK.", e);
}
cancellationTokenSource.Cancel();
return null;
}
catch (OperationCanceledException e)
Expand All @@ -118,7 +171,17 @@ public async Task<string> UploadFileToS3(AmazonS3Client s3Client, string filePat
}
catch (Exception e)
{
LogError("Error uploading file to S3: ", e);
if (!isMessageBoxShown)
{
isMessageBoxShown = true;
MessageBox.Show($"Error uploading file to S3: {e.Message}", "Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
isMessageBoxShown = false;
}
else
{
LogError("Error uploading file to S3: ", e);
}
cancellationTokenSource.Cancel();
return null;
}
}
Expand Down Expand Up @@ -164,18 +227,82 @@ public async Task<string> UploadFileToS3(AmazonS3Client s3Client, MemoryStream f
}
catch (AmazonS3Exception e)
{
LogError("Error uploading file to S3: ", e);
if (e.ErrorCode == "RequestTimeTooSkewed")
{
if (!isMessageBoxShown)
{
isMessageBoxShown = true;
MessageBox.Show($"Error uploading file to S3: A file took too long to upload. The difference between the request time and the current time is too large.", "Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
isMessageBoxShown = false;
}
else
{
Console.WriteLine($"Error uploading file to S3: A file took too long to upload. The difference between the request time and the current time is too large.");
}
}
else
{
if (!isMessageBoxShown)
{
isMessageBoxShown = true;
MessageBox.Show($"Error uploading file to S3: {e}", "Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
isMessageBoxShown = false;
}
else
{
LogError("Error uploading file to S3: ", e);
}
}
cancellationTokenSource.Cancel();
return null;
}
catch (AmazonServiceException e)
{
if (e.InnerException is WebException webEx && webEx.Status == WebExceptionStatus.NameResolutionFailure)
{
// Handle the NameResolutionFailure exception
if (!isMessageBoxShown)
{
isMessageBoxShown = true;
MessageBox.Show($"Error in uploading file to S3: Failed to resolve the hostname. Please check your network connection and the hostname.", "Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
isMessageBoxShown = false;
}
else
{
Console.WriteLine($"Error in uploading file to S3: Failed to resolve the hostname. Please check your network connection and the hostname.");
}
}
else
{
LogError("Error uploading file to S3: An error occurred within the AWS SDK.", e);
}
cancellationTokenSource.Cancel();
return null;
}
catch (OperationCanceledException e)
{
LogError("File Upload has been cancelled: ", e);
return null;
}
catch (Exception e)
{
LogError("Error uploading file to S3: ", e);
if (!isMessageBoxShown)
{
isMessageBoxShown = true;
MessageBox.Show($"Error uploading file to S3: {e}", "Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
isMessageBoxShown = false;
}
else
{
LogError("Error uploading file to S3: ", e);
}
cancellationTokenSource.Cancel();
return null;
}
}

/// <summary>
/// Extracts the contents of a ZIP file into Memory Stream and uploads them to Amazon S3 asynchronously.
/// Gets the contents of a folder and uploads them to Amazon S3 asynchronously.
/// </summary>
/// <param name="s3Client">The Amazon S3 client instance.</param>
/// <param name="bucketName">The name of the S3 bucket where the files will be uploaded.</param>
Expand Down
2 changes: 1 addition & 1 deletion LSC-Trainer.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
<UpdatePeriodically>false</UpdatePeriodically>
<UpdateRequired>false</UpdateRequired>
<MapFileExtensions>true</MapFileExtensions>
<ApplicationRevision>10</ApplicationRevision>
<ApplicationRevision>12</ApplicationRevision>
<ApplicationVersion>1.1.1.%2a</ApplicationVersion>
<UseApplicationTrust>false</UseApplicationTrust>
<PublishWizardCompleted>true</PublishWizardCompleted>
Expand Down
27 changes: 14 additions & 13 deletions docker/yolov5-training/train_and_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import sys
import traceback
import re

def get_hosts_and_node_rank():
"""
Expand Down Expand Up @@ -35,10 +36,17 @@ def run_script(args, use_module=False):
Returns:
`None`
"""
if use_module:
subprocess.run(["python3", "-m"] + args, check=True)
else:
subprocess.run(["python3"] + args, check=True)
try:
if use_module:
subprocess.run(["python3", "-m"] + args, check=True)
else:
subprocess.run(["python3"] + args, check=True)
except Exception as e:
instructions = "Please refer to your AWS Console Management -> SageMaker -> Training Jobs -> <Job Name> -> Monitor Section -> View Logs -> `/aws/sagemaker/TrainingJobs` Log group -> <Log Stream> -> Select host `algo-1` for more information."
with open("/opt/ml/output/failure", "w") as f:
f.write(instructions)
print(traceback.format_exc())
sys.exit(1)

def parse_arguments():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -121,7 +129,7 @@ def main():
"--data", args.data, "--hyp", "/opt/ml/input/config/custom-hyps.yaml" if args.hyp == "Custom" else args.hyp,
"--project", args.project, "--name", args.name,
"--patience", args.patience, "--workers", args.workers, "--optimizer", args.optimizer,
"--device", args.device, "--cache", "--exist-ok",
"--device", args.device, "--cache", "disk", "--exist-ok",
]
export_args = [
"/code/yolov5/export.py", "--img-size", args.img_size,
Expand All @@ -145,11 +153,4 @@ def main():
shutil.copy2("/opt/ml/output/data/results/weights/best.onnx", "/opt/ml/model/")

if __name__ == "__main__":
try:
main()
except Exception as e:
with open("/opt/ml/output/failure", "w") as f:
print(e)
f.write(str(e))
f.write(traceback.format_exc())
sys.exit(1)
main()

0 comments on commit 868cbeb

Please sign in to comment.