Skip to content

Commit

Permalink
client
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Nov 5, 2024
1 parent 3208f65 commit 2a39441
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions src/autotrain/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from dataclasses import dataclass
from typing import Optional
import os

import requests


"""
{
"project_name": "string",
Expand Down Expand Up @@ -51,6 +53,7 @@
}
"""


@dataclass
class Client:
host: Optional[str] = None
Expand All @@ -60,28 +63,36 @@ class Client:
def __post_init__(self):
if self.host is None:
self.host = "https://autotrain-projects-autotrain-advanced.hf.space/"

if self.token is None:
self.token = os.environ.get("HF_TOKEN")

if self.username is None:
self.username = os.environ.get("HF_USERNAME")

if self.token is None or self.username is None:
raise ValueError("Please provide a valid username and token")

self.headers = {
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json"
}


self.headers = {"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"}

def __str__(self):
return f"Client(host={self.host}, token=****, username={self.username})"

def __repr__(self):
return self.__str__()

def create(self, project_name: str, task: str, base_model: str, hardware: str, params: dict, column_mapping: dict, hub_dataset: str, train_split: str, valid_split: str):
return self.__str__()

def create(
self,
project_name: str,
task: str,
base_model: str,
hardware: str,
params: dict,
column_mapping: dict,
hub_dataset: str,
train_split: str,
valid_split: str,
):
url = f"{self.host}/api/create_project"
data = {
"project_name": project_name,
Expand All @@ -93,7 +104,7 @@ def create(self, project_name: str, task: str, base_model: str, hardware: str, p
"column_mapping": column_mapping,
"hub_dataset": hub_dataset,
"train_split": train_split,
"valid_split": valid_split
"valid_split": valid_split,
}
response = requests.post(url, headers=self.headers, json=data)
return response.json()

0 comments on commit 2a39441

Please sign in to comment.