refactor the code
bachvudinh committed Jul 16, 2024
1 parent 75b9d7f commit 4a311d0
Showing 4 changed files with 63 additions and 60 deletions.
Empty file added .github/workflows/test-flow.yml
Empty file.
example_env.sh: 3 changes (2 additions & 1 deletion)
@@ -1,3 +1,4 @@
export S3_ACCESS_KEY=minioadmin
export S3_SECRET_KEY=minioadmin
export S3_ENDPOINT_URL=http://127.0.0.1:9000
export S3_ENDPOINT_URL="http://127.0.0.1:9000"
mc alias set ALIAS $S3_ENDPOINT_URL $S3_ACCESS_KEY $S3_SECRET_KEY
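
For orientation, here is a minimal sketch of how these variables are typically consumed by a boto3 S3 client. The client construction inside S3Helper is not shown in this diff, so treat the snippet as an assumption about the wiring rather than the project's actual code:

import os
import boto3

# Build an S3 client against the MinIO endpoint exported in example_env.sh
s3_client = boto3.client(
    's3',
    aws_access_key_id=os.environ['S3_ACCESS_KEY'],
    aws_secret_access_key=os.environ['S3_SECRET_KEY'],
    endpoint_url=os.environ['S3_ENDPOINT_URL'],
)

# Smoke test: list the buckets visible with these credentials
print([b['Name'] for b in s3_client.list_buckets()['Buckets']])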
s3helper/s3_helper.py: 107 changes (58 additions & 49 deletions)
@@ -4,12 +4,24 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import sys
import logging
from datasets import load_dataset, Dataset
from datasets import load_dataset, Dataset, load_from_disk
from typing import Optional, Dict, Any

# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(asctime)s - %(message)s')

def find_files(directory: str, file_format: str):
matching_files = []

# Walk through the directory
for root, _, files in os.walk(directory):
for file in files:
# Check if the file ends with the specified format
if file.endswith(f".{file_format}"):
matching_files.append(os.path.join(root, file))

return matching_files

class S3Helper:
_instance = None

@@ -46,7 +58,7 @@ def validate_credentials(self):
logging.error(f"Invalid S3 credentials: {e}")
raise ValueError("Invalid S3 credentials")

def download_model(self, path_components: list, local_dir: str = './models'):
def download_file(self, path_components: list, local_dir: str):
bucket_name = path_components[0]
model_name = path_components[1]
objects = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=model_name)
@@ -59,38 +71,41 @@ def download_model(self, path_components: list, local_dir: str = './models'):
self.s3_client.download_file(bucket_name, file_key, file_path)
logging.info(f'Downloaded file: {file_key}')

def ensure_model_local(self, pretrained_model_name_or_path, local_dir):
path_components = pretrained_model_name_or_path.split("/")
def ensure_file_local(self, file_name_or_path: str, local_dir: str):
path_components = file_name_or_path.split("/")
if len(path_components) != 2:
logging.error("Cannot recognize bucket name and model name since having > 2 components")
raise ValueError("Cannot recognize bucket name and model name since having > 2 components")
model_local_path = os.path.join(local_dir, pretrained_model_name_or_path)
if not os.path.exists(model_local_path):
os.makedirs(model_local_path, exist_ok=True)
self.download_model(path_components, local_dir)
logging.error("Cannot recognize bucket name and file name: the path must have exactly 2 components ('bucket_name/file_name')")
raise ValueError("Cannot recognize bucket name and file name: the path must have exactly 2 components ('bucket_name/file_name')")
file_local_path = os.path.join(local_dir, file_name_or_path)
if not os.path.exists(file_local_path):
os.makedirs(file_local_path, exist_ok=True)
self.download_file(path_components, local_dir)
else:
logging.info(f"Model existed at: {model_local_path}, read from cache")
return model_local_path
if 'model' in file_name_or_path:
logging.info(f"Model already exists at: {file_local_path}, reading from cache")
elif 'dataset' in file_name_or_path:
logging.info(f"Dataset already exists at: {file_local_path}, reading from cache")
return file_local_path

def upload_to_s3(self, local_dir, bucket_name, model_name):
def upload_to_s3(self, local_dir, bucket_name, file_name):
for root, _, files in os.walk(local_dir):
for file in files:
local_file_path = os.path.join(root, file)
s3_key = os.path.relpath(local_file_path, local_dir)
self.s3_client.upload_file(local_file_path, bucket_name, os.path.join(model_name, s3_key))
logging.info(f'Uploaded {local_file_path} to s3://{bucket_name}/{model_name}/{s3_key}')
def download_dataset(self, path_components: list, local_dir: str = './datasets'):
bucket_name = path_components[0]
dataset_name = path_components[1]
objects = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=dataset_name)
for obj in objects.get('Contents', []):
file_key = obj['Key']
if file_key.endswith('/'):
continue # Skip directories
file_path = os.path.join(local_dir, bucket_name, file_key)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
self.s3_client.download_file(bucket_name, file_key, file_path)
logging.info(f'Downloaded dataset file: {file_key}')
self.s3_client.upload_file(local_file_path, bucket_name, os.path.join(file_name, s3_key))
logging.info(f'Uploaded {local_file_path} to s3://{bucket_name}/{file_name}/{s3_key}')
# def download_dataset(self, path_components: list, local_dir: str = './datasets'):
# bucket_name = path_components[0]
# dataset_name = path_components[1]
# objects = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=dataset_name)
# for obj in objects.get('Contents', []):
# file_key = obj['Key']
# if file_key.endswith('/'):
# continue # Skip directories
# file_path = os.path.join(local_dir, bucket_name, file_key)
# os.makedirs(os.path.dirname(file_path), exist_ok=True)
# self.s3_client.download_file(bucket_name, file_key, file_path)
# logging.info(f'Downloaded dataset file: {file_key}')

class S3HelperAutoModelForCausalLM(AutoModelForCausalLM):
@classmethod
@@ -114,9 +129,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, local_dir:
return super().from_pretrained(config_local_path, *model_args, **kwargs)
# define a custom load_dataset that reads from an S3 bucket
def s3_load_dataset(
path: str,
local_dir: str = './datasets',
dataset_name_or_path: str,
file_format: str = 'json',
local_dir: str = './datasets',
*args: Any,
**kwargs: Any
) -> Dataset:
@@ -125,33 +140,27 @@ def s3_load_dataset(
Args:
dataset_name_or_path (str): Path to the dataset in the format 'bucket_name/dataset_name'
file_format: File format of the dataset. Either 'json' or 'csv' or 'parquet'.
local_dir (str): Local directory to store downloaded datasets
file_format (str): Format of the dataset file (e.g., 'json', 'csv', 'parquet')
*args: Additional positional arguments to pass to load_dataset
**kwargs: Additional keyword arguments to pass to load_dataset
Returns:
Dataset: The loaded dataset
"""
s3_helper = S3Helper.get_instance()

# Split the path into bucket and dataset name
path_components = path.split("/")
if len(path_components) != 2:
raise ValueError("Path should be in the format 'bucket_name/dataset_name'")

bucket_name, dataset_name = path_components
dataset_local_path = os.path.join(local_dir, bucket_name, dataset_name)

# Download dataset if not exists locally
if not os.path.exists(dataset_local_path):
os.makedirs(dataset_local_path, exist_ok=True)
s3_helper.download_dataset(path_components, local_dir)
else:
logging.info(f"Dataset already exists at: {dataset_local_path}, using cached version")

# Construct the path to the data file
data_file_path = os.path.join(dataset_local_path, f"data.{file_format}")

dataset_local_path = s3_helper.ensure_file_local(dataset_name_or_path, local_dir)
# find_files already returns full paths under dataset_local_path, so no extra join is needed
dataset_local_paths = find_files(dataset_local_path, file_format)
train_local_paths = []
test_local_paths = []
for file in dataset_local_paths:
if "train" in file:
train_local_paths.append(file)
elif "test" in file:
test_local_paths.append(file)
else:
raise ValueError(f"Cannot infer split for file: {file}; expected 'train' or 'test' in the filename")
# Load and return the dataset
return load_dataset(file_format, data_files=data_file_path, *args, **kwargs)
return load_dataset(file_format, data_files={'train': train_local_paths, 'test': test_local_paths}, *args, **kwargs)
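
Putting the refactor together: s3_load_dataset now delegates the download-or-cache step to S3Helper.ensure_file_local, collects every matching file with find_files, and routes files whose names contain 'train' or 'test' into the corresponding splits passed to datasets.load_dataset. A hedged usage sketch, assuming the s3helper package exports both names; the bucket and dataset names below are placeholders, not taken from this repository:

from s3helper import S3Helper, s3_load_dataset

S3Helper()  # reads S3_ACCESS_KEY / S3_SECRET_KEY / S3_ENDPOINT_URL from the environment
dataset = s3_load_dataset('my-bucket/my-dataset', file_format='parquet')
print(dataset['train'][0])  # DatasetDict with 'train' and 'test' splits

Note that load_dataset will error if either split list is empty, so the sketch assumes the bucket prefix contains both train and test files.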
main.py → test_flows.py: 13 changes (3 additions & 10 deletions)
Expand Up @@ -3,7 +3,7 @@

os.environ['S3_ACCESS_KEY'] = 'minioadmin'
os.environ['S3_SECRET_KEY'] = 'minioadmin'
os.environ['S3_ENDPOINT_URL'] = 'http://172.17.0.2:9001'
os.environ['S3_ENDPOINT_URL'] = 'http://172.17.0.2:9000'
S3Helper()

# # Example usage
@@ -12,12 +12,5 @@
# tokenizer = S3HelperAutoTokenizer.from_pretrained(model_name)
# config = S3HelperAutoConfig.from_pretrained(model_name)
# Make sure S3Helper is initialized and environment variables are set
# Load a dataset
dataset = s3_load_dataset("modelhubjan/test_dataset")

# Use the dataset
for item in dataset:
print(item)

# You can also pass additional arguments to load_dataset
dataset = s3_load_dataset("modelhubjan/test_dataset", file_format='parquet', split='train')
# Load a dataset from S3 bucket
dataset = s3_load_dataset("jan-hq/test_dataset", file_format='parquet', split='train')
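
Because split='train' is passed, s3_load_dataset should hand back a single Dataset rather than a DatasetDict. A quick sanity check in the spirit of the example code removed above (purely illustrative):

print(len(dataset))  # number of rows in the train split
for item in dataset.select(range(3)):  # peek at the first few rows
    print(item)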
