Skip to content

Commit

Permalink
Modify paths, requests for HF proxy
Browse files (browse the repository at this point in the history)
  • Loading branch information
liveaverage committed Mar 7, 2024
1 parent c9aeb46 commit e7d7e1c
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions src/autotrain/backend.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -399,10 +399,8 @@ class NVCFRunner:
backend: str

def __post_init__(self):
self.token = None
self.nvcf_api = os.environ.get("NVCF_API")
self.nvcf_token = os.environ.get("NVCF_API_TOKEN")

self.hf_token = self.env_vars["HF_TOKEN"]
self.instance_map = {
"nvcf-l40": {"id": "67bb8939-c932-429a-a446-8ae898311856"},
"nvcf-h100x1": {"id": "848348f8-a4e2-4242-bce9-6baa1bd70a66"},
Expand Down Expand Up @@ -438,12 +436,13 @@ def _conf_nvcf(self, token, nvcf_type, url, method="POST", payload=None):
logger.info(
f"{self.job_name}: {method} - Successfully submitted NVCF job. Polling reqId for completion"
)
nvcf_reqid = response.headers.get("NVCF-REQID")
response_data = response.json()
nvcf_reqid = response_data.get("nvcfRequestId")
if nvcf_reqid:
logger.info(f"{self.job_name}: NVCF-REQID: {nvcf_reqid}")
logger.info(f"{self.job_name}: nvcfRequestId: {nvcf_reqid}")
return nvcf_reqid
else:
logger.warning(f"{self.job_name}: NVCF-REQID header is missing in the response")
logger.warning(f"{self.job_name}: nvcfRequestId key is missing in the response body")
return None

result = response.json()
Expand Down Expand Up @@ -503,7 +502,7 @@ def _poll_nvcf(self, url, token, method="get", timeout=86400, interval=30, op="p
raise TimeoutError(f"Operation '{op}' did not complete successfully within the timeout period.")

def create(self):
nvcf_url_submit = f"{self.nvcf_api}/v2/nvcf/pexec/functions/{self.instance_map[self.backend]['id']}"
nvcf_url_submit = f"{self.nvcf_api}/invoke/{self.instance_map[self.backend]['id']}"
nvcf_fr_payload = {
"cmd": [
"conda",
Expand All @@ -523,9 +522,9 @@ def create(self):
}

nvcf_fn_req = self._conf_nvcf(
token=self.nvcf_token, nvcf_type="job_submit", url=nvcf_url_submit, method="POST", payload=nvcf_fr_payload
token=self.hf_token, nvcf_type="job_submit", url=nvcf_url_submit, method="POST", payload=nvcf_fr_payload
)

nvcf_url_reqpoll = f"{self.nvcf_api}/v2/nvcf/pexec/status/{nvcf_fn_req}"
nvcf_url_reqpoll = f"{self.nvcf_api}/status/{nvcf_fn_req}"
logger.info(f"{self.job_name}: Polling : {nvcf_url_reqpoll}")
self._poll_nvcf(url=nvcf_url_reqpoll, token=self.nvcf_token, method="GET", timeout=172800, interval=20)
self._poll_nvcf(url=nvcf_url_reqpoll, token=self.hf_token, method="GET", timeout=172800, interval=20)

0 comments on commit e7d7e1c

Please sign in to comment.