Skip to content

Commit

Permalink
enable oauth
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Feb 29, 2024
1 parent d3d1644 commit a314daa
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 38 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
"src/autotrain/templates/index.html",
"src/autotrain/templates/error.html",
"src/autotrain/templates/duplicate.html",
"src/autotrain/templates/login.html",
],
),
],
Expand Down
46 changes: 24 additions & 22 deletions src/autotrain/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
USE_OAUTH = int(os.environ.get("USE_OAUTH", "0"))

if "SPACE_ID" not in os.environ:
USE_OAUTH = 0

HIDDEN_PARAMS = [
"token",
"project_name",
Expand Down Expand Up @@ -241,8 +244,8 @@ async def read_form(request: Request):
if HF_TOKEN is None and USE_OAUTH == 0:
return templates.TemplateResponse("error.html", {"request": request})

if USE_OAUTH == 1 and HF_TOKEN is None:
return RedirectResponse("/login/huggingface")
if USE_OAUTH == 1 and HF_TOKEN is None and "oauth_info" not in request.session:
return templates.TemplateResponse("login.html", {"request": request})

if HF_TOKEN is None:
if os.environ.get("SPACE_ID") is None:
Expand All @@ -255,30 +258,23 @@ async def read_form(request: Request):
if USE_OAUTH == 1:
logger.info(request.session["oauth_info"])

_, _, USERS = app_utils.user_validation()
token = HF_TOKEN if USE_OAUTH == 0 else request.session["oauth_info"]["access_token"]

_users = app_utils.user_validation(user_token=token)
context = {
"request": request,
"valid_users": USERS,
"valid_users": _users,
"enable_ngc": ENABLE_NGC,
"enable_nvcf": ENABLE_NVCF,
"enable_local": AUTOTRAIN_LOCAL,
}
return templates.TemplateResponse("index.html", context)


@app.post("/set_token", response_class=JSONResponse)
async def set_token(request: Request, token: str):
"""
This function is used to set the token
:param request:
:param token: str
:return: JSONResponse
"""
if token.startswith("hf_"):
global HF_TOKEN
os.environ["HF_TOKEN"] = token
HF_TOKEN = token
return {"error": "Invalid token"}
@app.get("/logout", response_class=HTMLResponse)
async def oauth_logout(request: Request):
request.session.pop("oauth_info", None)
return RedirectResponse("/")


@app.get("/params/{task}", response_class=JSONResponse)
Expand Down Expand Up @@ -364,6 +360,7 @@ async def fetch_model_choices(task: str):

@app.post("/create_project", response_class=JSONResponse)
async def handle_form(
request: Request,
project_name: str = Form(...),
task: str = Form(...),
base_model: str = Form(...),
Expand All @@ -386,9 +383,14 @@ async def handle_form(
status_code=409, detail="Another job is already running. Please wait for it to finish."
)

if HF_TOKEN is None:
if HF_TOKEN is None and USE_OAUTH == 0:
return {"error": "HF_TOKEN not set"}

if USE_OAUTH == 1 and HF_TOKEN is None:
token = request.session["oauth_info"]["access_token"]
else:
token = HF_TOKEN

params = json.loads(params)
column_mapping = json.loads(column_mapping)

Expand All @@ -398,7 +400,7 @@ async def handle_form(
if task == "image-classification":
dset = AutoTrainImageClassificationDataset(
train_data=training_files[0],
token=HF_TOKEN,
token=token,
project_name=project_name,
username=autotrain_user,
valid_data=validation_files[0] if validation_files else None,
Expand All @@ -409,7 +411,7 @@ async def handle_form(
dset = AutoTrainDreamboothDataset(
concept_images=data_files_training,
concept_name=params["prompt"],
token=HF_TOKEN,
token=token,
project_name=project_name,
username=autotrain_user,
local=hardware.lower() == "local",
Expand Down Expand Up @@ -443,7 +445,7 @@ async def handle_form(
dset_args = dict(
train_data=training_files,
task=dset_task,
token=HF_TOKEN,
token=token,
project_name=project_name,
username=autotrain_user,
column_mapping=column_mapping,
Expand All @@ -457,7 +459,7 @@ async def handle_form(
data_path = dset.prepare()
app_params = AppParams(
job_params_json=json.dumps(params),
token=HF_TOKEN,
token=token,
project_name=project_name,
username=autotrain_user,
task=task,
Expand Down
53 changes: 45 additions & 8 deletions src/autotrain/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,42 @@ def kill_process_by_pid(pid):


def user_authentication(token):
if token.startswith("hf_oauth"):
_api_url = config.HF_API + "/oauth/userinfo"
else:
_api_url = config.HF_API + "/api/whoami-v2"
headers = {}
cookies = {}
if token.startswith("hf_"):
headers["Authorization"] = f"Bearer {token}"
else:
cookies = {"token": token}
try:
response = requests.get(
_api_url,
headers=headers,
cookies=cookies,
timeout=3,
)
except (requests.Timeout, ConnectionError) as err:
logger.error(f"Failed to request whoami-v2 - {repr(err)}")
raise Exception("Hugging Face Hub is unreachable, please try again later.")
resp = response.json()
user_info = {}
if "error" in resp:
return resp
if token.startswith("hf_oauth"):
user_info["id"] = resp["sub"]
user_info["name"] = resp["preferred_username"]
user_info["orgs"] = [resp["orgs"][k]["preferred_username"] for k in range(len(resp["orgs"]))]
else:
user_info["id"] = resp["id"]
user_info["name"] = resp["name"]
user_info["orgs"] = [resp["orgs"][k]["name"] for k in range(len(resp["orgs"]))]
return user_info


def user_authentication_deprecated(token):
logger.info("Authenticating user...")
headers = {}
cookies = {}
Expand All @@ -77,10 +113,9 @@ def user_authentication(token):
def _login_user(user_token):
user_info = user_authentication(token=user_token)
username = user_info["name"]

user_can_pay = user_info["canPay"]
orgs = user_info["orgs"]

user_can_pay = user_info["canPay"]
valid_orgs = [org for org in orgs if org["canPay"] is True]
valid_orgs = [org for org in valid_orgs if org["roleInOrg"] in ("admin", "write")]
valid_orgs = [org["name"] for org in valid_orgs]
Expand All @@ -90,18 +125,20 @@ def _login_user(user_token):
return user_token, valid_can_pay, who_is_training


def user_validation():
user_token = os.environ.get("HF_TOKEN", None)

def user_validation(user_token):
if user_token is None:
raise Exception("Please login with a write token.")

user_token, valid_can_pay, who_is_training = _login_user(user_token)

if user_token is None or len(user_token) == 0:
raise Exception("Invalid token. Please login with a write token.")

return user_token, valid_can_pay, who_is_training
user_info = user_authentication(token=user_token)
username = user_info["name"]
orgs = user_info["orgs"]

who_is_training = [username] + orgs

return who_is_training


def run_training(params, task_id, local=False, wait=False):
Expand Down
8 changes: 1 addition & 7 deletions src/autotrain/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import urllib.parse

import fastapi
from authlib.integrations.base_client.errors import MismatchingStateError
from authlib.integrations.starlette_client import OAuth
from fastapi.responses import RedirectResponse
from starlette.middleware.sessions import SessionMiddleware
Expand Down Expand Up @@ -82,12 +81,7 @@ async def oauth_login(request: fastapi.Request):
@app.get("/auth")
async def auth(request: fastapi.Request) -> RedirectResponse:
"""Endpoint that handles the OAuth callback."""
# oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore
try:
oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore
except MismatchingStateError:
print("Session dict:", dict(request.session))
raise
oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore
request.session["oauth_info"] = oauth_info
return _redirect_to_target(request)

Expand Down
2 changes: 1 addition & 1 deletion src/autotrain/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
<select id="hardware" name="hardware"
class="mt-1 block w-full border border-gray-300 px-3 py-2 bg-white rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
{% if enable_local == 1 %}
<option value="Local">Local</option>
<option value="Local">Local/Space</option>
{% else %}
<optgroup label="Hugging Face Spaces">
<option value="A10G Large">A10G Large</option>
Expand Down
24 changes: 24 additions & 0 deletions src/autotrain/templates/login.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<!doctype html>
<html>

<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://cdn.tailwindcss.com"></script>
</head>

<body>
<header class="bg-white-800 text-white p-4">
<div class="container mx-auto flex justify-between items-center">
<img src="/static/logo.png" alt="AutoTrain" , class="w-32">
</div>
</header>

<div class="form-container max-w-lg mx-auto mt-10 p-6 shadow-2xl">
<h1 class="text-2xl font-bold mb-10">Login</h1>
<p class="text-gray-500 text-sm mb-10">Please <a href="/login/huggingface" target="_blank">login</a> to use
AutoTrain</p>
</div>
</body>

</html>

0 comments on commit a314daa

Please sign in to comment.