Skip to content

Commit

Permalink
mount the dedicated storage for each function (#1408)
Browse files Browse the repository at this point in the history
* add mount /function_data space
  • Loading branch information
akihikokuroda authored Jul 26, 2024
1 parent b9b4f85 commit 843ee36
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ data:
- mountPath: /data
name: user-storage
subPath: {{`{{ user_id }}`}}
- mountPath: /function_data
name: user-storage
subPath: {{`{{ function_data }}`}}
env:
# Environment variables for Ray TLS authentication.
# See https://docs.ray.io/en/latest/ray-core/configure.html#tls-authentication for more details.
Expand Down Expand Up @@ -352,6 +355,9 @@ data:
- mountPath: /data
name: user-storage
subPath: {{`{{ user_id }}`}}
- mountPath: /function_data
name: user-storage
subPath: {{`{{ function_data }}`}}
{{- if .Values.useCertManager }}
- mountPath: /tmp/tls
name: cert-tls
Expand Down
17 changes: 10 additions & 7 deletions client/qiskit_serverless/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,22 +475,25 @@ def upload(self, program: QiskitFunction):
def get_jobs(self, **kwargs) -> List[Job]:
return self._job_client.list(**kwargs)

def files(self) -> List[str]:
return self._files_client.list()
def files(self, provider: Optional[str] = None) -> List[str]:
return self._files_client.list(provider)

def file_download(
self,
file: str,
target_name: Optional[str] = None,
download_location: str = "./",
provider: Optional[str] = None,
):
return self._files_client.download(file, download_location, target_name)
return self._files_client.download(
file, download_location, target_name, provider
)

def file_delete(self, file: str):
return self._files_client.delete(file)
def file_delete(self, file: str, provider: Optional[str] = None):
return self._files_client.delete(file, provider)

def file_upload(self, file: str):
return self._files_client.upload(file)
def file_upload(self, file: str, provider: Optional[str] = None):
return self._files_client.upload(file, provider)

def list(self, **kwargs) -> List[QiskitFunction]:
"""Returns list of available programs."""
Expand Down
18 changes: 12 additions & 6 deletions client/qiskit_serverless/core/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,18 @@ def __init__(self, host: str, token: str, version: str):
self._token = token

def download(
self, file: str, download_location: str, target_name: Optional[str] = None
self,
file: str,
download_location: str,
target_name: Optional[str] = None,
provider: Optional[str] = None,
) -> Optional[str]:
"""Downloads file."""
tracer = trace.get_tracer("client.tracer")
with tracer.start_as_current_span("files.download"):
with requests.get(
f"{self.host}/api/{self.version}/files/download/",
params={"file": file},
params={"file": file, "provider": provider},
stream=True,
headers={"Authorization": f"Bearer {self._token}"},
timeout=REQUESTS_TIMEOUT,
Expand All @@ -80,14 +84,15 @@ def download(
progress_bar.close()
return file_name

def upload(self, file: str) -> Optional[str]:
def upload(self, file: str, provider: Optional[str] = None) -> Optional[str]:
"""Uploads file."""
tracer = trace.get_tracer("client.tracer")
with tracer.start_as_current_span("files.upload"):
with open(file, "rb") as f:
with requests.post(
f"{self.host}/api/{self.version}/files/upload/",
files={"file": f},
data={"provider": provider},
stream=True,
headers={"Authorization": f"Bearer {self._token}"},
timeout=REQUESTS_TIMEOUT,
Expand All @@ -97,27 +102,28 @@ def upload(self, file: str) -> Optional[str]:
return "Upload failed"
return "Can not open file"

def list(self) -> List[str]:
def list(self, provider: Optional[str] = None) -> List[str]:
"""Returns list of available files to download produced by programs,"""
tracer = trace.get_tracer("client.tracer")
with tracer.start_as_current_span("files.list"):
response_data = safe_json_request(
request=lambda: requests.get(
f"{self.host}/api/{self.version}/files/",
params={"provider": provider},
headers={"Authorization": f"Bearer {self._token}"},
timeout=REQUESTS_TIMEOUT,
)
)
return response_data.get("results", [])

def delete(self, file: str) -> Optional[str]:
def delete(self, file: str, provider: Optional[str] = None) -> Optional[str]:
"""Deletes file uploaded or produced by the programs,"""
tracer = trace.get_tracer("client.tracer")
with tracer.start_as_current_span("files.delete"):
response_data = safe_json_request(
request=lambda: requests.delete(
f"{self.host}/api/{self.version}/files/delete/",
data={"file": file},
data={"file": file, "provider": provider},
headers={
"Authorization": f"Bearer {self._token}",
"format": "json",
Expand Down
6 changes: 5 additions & 1 deletion gateway/api/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def submit_job(job: Job) -> Job:
return job


def create_ray_cluster(
def create_ray_cluster( # pylint: disable=too-many-branches
job: Job,
cluster_name: Optional[str] = None,
cluster_data: Optional[str] = None,
Expand Down Expand Up @@ -250,14 +250,18 @@ def create_ray_cluster(
node_image = settings.RAY_NODE_IMAGE

# if user specified image use specified image
function_data = user.username
if job.program.image is not None:
node_image = job.program.image
if job.program.provider.name:
function_data = job.program.provider.name

cluster = get_template("rayclustertemplate.yaml")
manifest = cluster.render(
{
"cluster_name": cluster_name,
"user_id": user.username,
"function_data": function_data,
"node_image": node_image,
"workers": job_config.workers,
"min_workers": job_config.min_workers,
Expand Down
63 changes: 56 additions & 7 deletions gateway/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Program,
Job,
RuntimeJob,
Provider,
)
from .ray import get_job_handler
from .serializers import (
Expand Down Expand Up @@ -580,15 +581,39 @@ class FilesViewSet(viewsets.ViewSet):

BASE_NAME = "files"

def list_user_providers(self, user):
"""list provider names that the user in"""
provider_list = []
providers = Provider.objects.all()
for instance in providers:
if instance.admin_group in user.groups.all():
provider_list.append(instance.name)
return provider_list

def check_user_has_provider(self, user, provider_name):
"""check if user has the provider"""
return provider_name in self.list_user_providers(user)

def list(self, request):
"""List of available for user files."""
response = Response(
{"message": "Requested file was not found."},
status=status.HTTP_404_NOT_FOUND,
)
files = []
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.files.list", context=ctx):
user_dir = request.user.username
provider_name = request.query_params.get("provider")
if provider_name is not None:
if self.check_user_has_provider(request.user, provider_name):
user_dir = provider_name
else:
return response
user_dir = os.path.join(
sanitize_file_path(settings.MEDIA_ROOT),
sanitize_file_path(request.user.username),
sanitize_file_path(user_dir),
)
if os.path.exists(user_dir):
files = [
Expand All @@ -615,17 +640,23 @@ def download(self, request): # pylint: disable=invalid-name
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.files.download", context=ctx):
requested_file_name = request.query_params.get("file")
provider_name = request.query_params.get("provider")
if requested_file_name is not None:
user_dir = request.user.username
if provider_name is not None:
if self.check_user_has_provider(request.user, provider_name):
user_dir = provider_name
else:
return response
# look for file in user's folder
filename = os.path.basename(requested_file_name)
user_dir = os.path.join(
sanitize_file_path(settings.MEDIA_ROOT),
sanitize_file_path(request.user.username),
sanitize_file_path(user_dir),
)
file_path = os.path.join(
sanitize_file_path(user_dir), sanitize_file_path(filename)
)

if os.path.exists(user_dir) and os.path.exists(file_path) and filename:
chunk_size = 8192
# note: we do not use with statements as Streaming response closing file itself.
Expand Down Expand Up @@ -656,14 +687,20 @@ def delete(self, request): # pylint: disable=invalid-name
if request.data and "file" in request.data:
# look for file in user's folder
filename = os.path.basename(request.data["file"])
provider_name = request.data.get("provider")
user_dir = request.user.username
if provider_name is not None:
if self.check_user_has_provider(request.user, provider_name):
user_dir = provider_name
else:
return response
user_dir = os.path.join(
sanitize_file_path(settings.MEDIA_ROOT),
sanitize_file_path(request.user.username),
sanitize_file_path(user_dir),
)
file_path = os.path.join(
sanitize_file_path(user_dir), sanitize_file_path(filename)
)

if os.path.exists(user_dir) and os.path.exists(file_path) and filename:
os.remove(file_path)
response = Response(
Expand All @@ -675,20 +712,32 @@ def delete(self, request): # pylint: disable=invalid-name
@action(methods=["POST"], detail=False)
def upload(self, request): # pylint: disable=invalid-name
"""Upload selected file."""
response = Response(
{"message": "Requested file was not found."},
status=status.HTTP_404_NOT_FOUND,
)
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.files.download", context=ctx):
upload_file = request.FILES["file"]
filename = os.path.basename(upload_file.name)
user_dir = request.user.username
if request.data and "provider" in request.data:
provider_name = request.data["provider"]
if provider_name is not None:
if self.check_user_has_provider(request.user, provider_name):
user_dir = provider_name
else:
return response
user_dir = os.path.join(
sanitize_file_path(settings.MEDIA_ROOT),
sanitize_file_path(request.user.username),
sanitize_file_path(user_dir),
)
file_path = os.path.join(
sanitize_file_path(user_dir), sanitize_file_path(filename)
)
with open(file_path, "wb+") as destination:
for chunk in upload_file.chunks():
destination.write(chunk)
return Response({"message": file_path})
return Response({"message": file_path})
return Response("server error", status=status.HTTP_500_INTERNAL_SERVER_ERROR)
Loading

0 comments on commit 843ee36

Please sign in to comment.