Skip to content

Commit

Permalink
fix ecr
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Dec 7, 2023
1 parent 08aa340 commit 9295ec1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 92 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build_and_push.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Build and Push Docker Images
name: Build & Push

on:
push:
Expand All @@ -12,7 +12,7 @@ env:

jobs:
dockerhub:
name: Build and push to Docker Hub
name: Docker Hub
runs-on: ubuntu-latest
steps:
- name: Check out the repo
Expand All @@ -30,8 +30,9 @@ jobs:
- name: Build and Push Docker Image
run: make docker
ecr:
name: Build and push to Amazon ECR
name: Amazon ECR
runs-on: ubuntu-latest
# Job ordering in GitHub Actions is declared with `needs`; `depends-on` is not a job key.
needs: dockerhub
steps:
- name: Checkout
uses: actions/checkout@v4
Expand All @@ -42,7 +43,6 @@ jobs:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
role-to-assume: ${{ secrets.AWS_ROLE }}

- name: Login to Amazon ECR
id: login-ecr
Expand Down
89 changes: 1 addition & 88 deletions src/autotrain/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,93 +160,6 @@ def _dreambooth_munge_data(params, username):
return params.image_path


@dataclass
class EndpointsRunner:
    """Start a training job by creating an endpoint that runs the AutoTrain image.

    ``params`` carries the training configuration (token, project name, model,
    repo id / username, ...). ``backend`` is one of the short keys of
    ``endpoints_backends`` and selects the hardware the endpoint is created on.

    Raises:
        ValueError: from ``__post_init__`` when ``params`` has neither a
            ``repo_id`` nor a ``username``.
    """

    params: Union[TextClassificationParams, ImageClassificationParams, LLMTrainingParams]
    backend: str

    # Seconds to wait for the endpoints API. Not annotated on purpose, so the
    # dataclass machinery does not turn it into an ``__init__`` field.
    _REQUEST_TIMEOUT = 60

    def __post_init__(self):
        # Short backend key -> "<vendor>_<region>_<accelerator>_<size>_<instance type>".
        self.endpoints_backends = {
            "ep-aws-useast1-s": "aws_us-east-1_gpu_small_g4dn.xlarge",
            "ep-aws-useast1-m": "aws_us-east-1_gpu_medium_g5.2xlarge",
            "ep-aws-useast1-l": "aws_us-east-1_gpu_large_g4dn.12xlarge",
            "ep-aws-useast1-xl": "aws_us-east-1_gpu_xlarge_p4de",
            "ep-aws-useast1-2xl": "aws_us-east-1_gpu_2xlarge_p4de",
            "ep-aws-useast1-4xl": "aws_us-east-1_gpu_4xlarge_p4de",
            "ep-aws-useast1-8xl": "aws_us-east-1_gpu_8xlarge_p4de",
        }

        # The namespace is the owner part of "owner/name" when a repo id is
        # given; otherwise fall back to the explicit username.
        if self.params.repo_id is not None:
            self.username = self.params.repo_id.split("/")[0]
        elif self.params.username is not None:
            self.username = self.params.username
        else:
            raise ValueError("Must provide either repo_id or username")

        self.api_url = f"https://api.endpoints.huggingface.cloud/v2/endpoint/{self.username}"

        # Always define task_id so a missing mapping is detected explicitly
        # instead of surfacing later as an AttributeError.
        # NOTE(review): only LLM training has an id here, although prepare()
        # also accepts TextClassificationParams - add its id once confirmed.
        self.task_id = None
        if isinstance(self.params, LLMTrainingParams):
            self.task_id = 9

    def _resolve_task_id(self):
        """Return the task id as a string, or raise if none is defined.

        Raises:
            NotImplementedError: when no task id is set for this params type.
        """
        if self.task_id is None:
            raise NotImplementedError(
                f"No endpoint task id is defined for {type(self.params).__name__}"
            )
        return str(self.task_id)

    def _create_endpoint(self):
        """POST the endpoint definition to the API and return the decoded JSON reply.

        Raises:
            KeyError: when ``backend`` is not a key of ``endpoints_backends``.
            NotImplementedError: when no task id is defined for ``params``.
        """
        hardware = self.endpoints_backends[self.backend]
        # Every table entry has exactly five "_"-separated parts.
        vendor, region, accelerator, instance_size, instance_type = hardware.split("_")
        task_id = self._resolve_task_id()

        payload = {
            "accountId": self.username,
            "compute": {
                "accelerator": accelerator,
                "instanceSize": instance_size,
                "instanceType": instance_type,
                "scaling": {"maxReplica": 1, "minReplica": 1},
            },
            "model": {
                "framework": "custom",
                "image": {
                    "custom": {
                        "env": {
                            "HF_TOKEN": self.params.token,
                            "AUTOTRAIN_USERNAME": self.username,
                            "PROJECT_NAME": self.params.project_name,
                            "PARAMS": self.params.model_dump_json(),
                            "DATA_PATH": self.params.data_path,
                            "TASK_ID": task_id,
                            "MODEL": self.params.model,
                            "OUTPUT_MODEL_REPO": self.params.repo_id,
                            "ENDPOINT_ID": f"{self.username}/{self.params.project_name}",
                        },
                        "health_route": "/",
                        "port": 7860,
                        "url": "public.ecr.aws/z4c3o6n6/autotrain:latest",
                    }
                },
                "repository": "autotrain-projects/autotrain-advanced",
                "revision": "main",
                "task": "custom",
            },
            "name": self.params.project_name,
            "provider": {"region": region, "vendor": vendor},
            "type": "protected",
        }
        headers = {"Authorization": f"Bearer {self.params.token}"}
        # Without a timeout a stalled connection would block the caller forever.
        r = requests.post(self.api_url, json=payload, headers=headers, timeout=self._REQUEST_TIMEOUT)
        response = r.json()  # decode once; reused for logging and the return value
        # NOTE(review): if the API echoes the env block, this logs HF_TOKEN - confirm.
        logger.info(response)
        return response

    def prepare(self):
        """Prepare the dataset for the task, create the endpoint, return the API reply.

        Raises:
            NotImplementedError: for params types other than LLM training and
                text classification, or when no task id is defined for the type.
        """
        if isinstance(self.params, LLMTrainingParams):
            munge_data = _llm_munge_data
        elif isinstance(self.params, TextClassificationParams):
            munge_data = _text_clf_munge_data
        else:
            raise NotImplementedError

        # Fail before the data step when the endpoint could not be created anyway.
        self._resolve_task_id()

        self.params.data_path = munge_data(self.params, self.username)
        return self._create_endpoint()


@dataclass
class SpaceRunner:
params: Union[
Expand Down Expand Up @@ -410,7 +323,7 @@ def _create_endpoint(self):
},
"health_route": "/",
"port": 7860,
"url": "public.ecr.aws/z4c3o6n6/autotrain:latest",
"url": "public.ecr.aws/z4c3o6n6/autotrain-api:latest",
}
},
"repository": "autotrain-projects/autotrain-advanced",
Expand Down

0 comments on commit 9295ec1

Please sign in to comment.