Add support for Tenant ID #1097

Draft
wants to merge 11 commits into
base: main
from
20 changes: 20 additions & 0 deletions docs/source/contributors/system-architecture.md
Expand Up @@ -491,3 +491,23 @@ Theoretically speaking, enabling a kernel for use in other frameworks amounts to
```{seealso}
This topic is covered in the [Developers Guide](../developers/index.rst).
```

## Multiple Tenant Support

Enterprise Gateway offers minimally viable support for multi-tenant environments by tracking managed kernels by their tenant _ID_. The client supplies this ID when starting a kernel by adding a UUID-formatted string to the kernel start request's `env` stanza under the key `JUPYTER_GATEWAY_TENANT_ID`.

```JSON
{"env": {"JUPYTER_GATEWAY_TENANT_ID": "f730794d-d175-40fa-b819-2a67d5308210"}}
```
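As a rough client-side illustration, the request body above could be assembled as follows. This is only a sketch: `build_start_request` and `TENANT_ENV_KEY` are hypothetical names introduced here, not part of Enterprise Gateway, and the UUID check is an assumption about what a careful client might do before sending the request.

```python
import json
import uuid

# Key named in the text above; the constant name itself is hypothetical.
TENANT_ENV_KEY = "JUPYTER_GATEWAY_TENANT_ID"


def build_start_request(kernel_name: str, tenant_id: str) -> str:
    """Return a JSON kernel start body whose env stanza carries the tenant ID."""
    # uuid.UUID raises ValueError when the string is not a well-formed UUID;
    # str() yields the canonical lowercase, hyphenated form.
    normalized = str(uuid.UUID(tenant_id))
    body = {"name": kernel_name, "env": {TENANT_ENV_KEY: normalized}}
    return json.dumps(body)


if __name__ == "__main__":
    print(build_start_request("python3", "f730794d-d175-40fa-b819-2a67d5308210"))
```

Validating on the client keeps malformed IDs from ever reaching the server, where they would otherwise be stored as an opaque string.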

Likewise, when calling the `/api/kernels` endpoint to get the list of active kernels, tenant-aware applications should add a `tenant_id` query parameter in order to get appropriate managed kernel information.

```
GET /api/kernels?tenant_id=f730794d-d175-40fa-b819-2a67d5308210
```
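A tenant-aware client might build that URL along these lines. The helper name `kernels_list_url` and the example base URL are assumptions made for illustration; only the `/api/kernels` path and the `tenant_id` parameter come from the text above.

```python
from typing import Optional
from urllib.parse import urlencode


def kernels_list_url(base_url: str, tenant_id: Optional[str] = None) -> str:
    """Return the kernel listing URL, adding tenant_id only when one is given."""
    url = base_url.rstrip("/") + "/api/kernels"
    if tenant_id:
        # urlencode escapes the value so it is safe to place in a query string.
        url += "?" + urlencode({"tenant_id": tenant_id})
    return url


if __name__ == "__main__":
    # The host and port here are placeholders, not a documented default.
    print(kernels_list_url("http://localhost:8888", "f730794d-d175-40fa-b819-2a67d5308210"))
```

Omitting the parameter leaves the URL unchanged, which matches the behavior existing clients already rely on.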

Kernel start requests that do not include a `tenant_id` will have their kernels associated with the `UNIVERSAL_TENANT_ID`, which merely acts as a catch-all and allows common code usage relative to existing clients.

Enterprise Gateway adds the environment variable `KERNEL_TENANT_ID` to the kernel's _launch_ environment so that the value is available to the kernel's launch framework. The value is currently not available to the kernel itself. If the original request also included a `KERNEL_TENANT_ID` in the body's `env` stanza, that value is overwritten with the one corresponding to `tenant_id` (or with `UNIVERSAL_TENANT_ID` if `tenant_id` was not provided).
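The override rule can be pictured with the simplified sketch below. It is not the handler's actual code: `build_launch_env` is a hypothetical function, and the filtering is reduced to "keep `KERNEL_*` keys" for brevity, whereas the real handler also consults its configured allow lists.

```python
UNIVERSAL_TENANT_ID = "27182818-2845-9045-2353-602874713527"


def build_launch_env(request_env: dict) -> dict:
    """Keep KERNEL_* values from the request, then force KERNEL_TENANT_ID."""
    # Simplified filtering (an assumption): pass through only KERNEL_* keys.
    env = {k: v for k, v in request_env.items() if k.startswith("KERNEL_")}
    # Fall back to the catch-all tenant when the client supplied no tenant ID.
    tenant_id = request_env.get("JUPYTER_GATEWAY_TENANT_ID", UNIVERSAL_TENANT_ID)
    # Any client-provided KERNEL_TENANT_ID is overwritten at this point.
    env["KERNEL_TENANT_ID"] = tenant_id
    return env
```

Assigning `KERNEL_TENANT_ID` last is what guarantees that a client cannot set it directly through the `env` stanza.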

Kernel specifications and other resources do not currently adhere to tenant-based _partitioning_.
9 changes: 9 additions & 0 deletions docs/source/users/kernel-envs.md
Expand Up @@ -94,6 +94,15 @@ There are several supported `KERNEL_` variables that the Enterprise Gateway serv
(https://googlecloudplatform.github.io/spark-on-k8s-operator/docs/api-docs.html#sparkoperator.k8s.io/v1beta2.SparkApplicationSpec)
sparkConfigMap for more information.

KERNEL_TENANT_ID=<system provided>
Indicates the tenant ID (UUID string) corresponding to the kernel. This value
is derived from the optional `tenant_id` provided by the client application and
is meant to represent the entity or organization with which the client application
is associated. If `tenant_id` is not provided on the kernel start request, then
`KERNEL_TENANT_ID` will hold a value of `"27182818-2845-9045-2353-602874713527"`
(the `UNIVERSAL_TENANT_ID`), which provides backward-compatible support for
older applications.

KERNEL_UID=<from user> or 1000
Containers only. This value represents the user id in which the container will run.
The default value is 1000 representing the jovyan user - which is how all kernel images
15 changes: 14 additions & 1 deletion enterprise_gateway/services/api/swagger.yaml
Expand Up @@ -112,6 +112,13 @@ paths:
summary: List the JSON data for all currently running kernels
tags:
- kernels
parameters:
- name: tenant_id
in: query
description: When present, the list of running kernels is limited to those belonging to the given tenant ID.
required: false
type: string
format: uuid
responses:
200:
description: List of running kernels
Expand All @@ -137,11 +144,17 @@ paths:
name:
type: string
description: Kernel spec name (defaults to default kernel spec for server)
tenant_id:
type: string
format: uuid
description: |
The (optional) UUID of the tenant making the request. The list of active
(running) kernels will be filtered by this value.
env:
type: object
description: |
A dictionary of environment variables and values to include in the
kernel process - subject to whitelisting.
kernel process - subject to filtering.
additionalProperties:
type: string
responses:
77 changes: 47 additions & 30 deletions enterprise_gateway/services/kernels/handlers.py
Expand Up @@ -8,9 +8,11 @@
import jupyter_server.services.kernels.handlers as jupyter_server_handlers
import tornado
from jupyter_client.jsonutil import date_default
from jupyter_server.utils import ensure_async
from tornado import web

from ...mixins import CORSMixin, JSONErrorsMixin, TokenAuthorizationMixin
from ..kernels.remotemanager import UNIVERSAL_TENANT_ID


class MainKernelHandler(
Expand Down Expand Up @@ -45,35 +47,42 @@ async def post(self):
if len(kernels) >= max_kernels:
raise tornado.web.HTTPError(403, "Resource Limit")

# Try to get env vars from the request body
env = {}
# Try to get tenant_id and env vars from the request body
model = self.get_json_body()
if model is not None and "env" in model:
if not isinstance(model["env"], dict):
raise tornado.web.HTTPError(400)
# Start with the PATH from the current env. Do not provide the entire environment
# which might contain server secrets that should not be passed to kernels.
env = {"PATH": os.getenv("PATH", "")}
# Whitelist environment variables from current process environment
env.update(
{
key: value
for key, value in os.environ.items()
if key in self.env_process_whitelist
}
)
# Whitelist KERNEL_* args and those allowed by configuration from client. If all
# envs are requested, just use the keys from the payload.
env_whitelist = self.env_whitelist
if env_whitelist == ["*"]:
env_whitelist = model["env"].keys()
env.update(
{
key: value
for key, value in model["env"].items()
if key.startswith("KERNEL_") or key in env_whitelist
}
)
if model is not None:
tenant_id = UNIVERSAL_TENANT_ID
if "env" in model:
if not isinstance(model["env"], dict):
raise tornado.web.HTTPError(400)
# Start with the PATH from the current env. Do not provide the entire environment
# which might contain server secrets that should not be passed to kernels.
env = {"PATH": os.getenv("PATH", "")}
# Whitelist environment variables from current process environment
env.update(
{
key: value
for key, value in os.environ.items()
if key in self.env_process_whitelist
}
)
# Whitelist KERNEL_* args and those allowed by configuration from client. If all
# envs are requested, just use the keys from the payload.
env_whitelist = self.env_whitelist
if env_whitelist == ["*"]:
env_whitelist = model["env"].keys()
env.update(
{
key: value
for key, value in model["env"].items()
if key.startswith("KERNEL_") or key in env_whitelist
}
)
# Pull client-provided tenant_id from env, defaulting to UNIVERSAL_TENANT_ID.
tenant_id = model["env"].get("JUPYTER_GATEWAY_TENANT_ID", tenant_id)

# Set KERNEL_TENANT_ID. If already present, we override with the value in the body
env["KERNEL_TENANT_ID"] = tenant_id
# If kernel_headers are configured, fetch each of those and include in start request
kernel_headers = {}
missing_headers = []
Expand All @@ -97,7 +106,9 @@ async def post(self):
# so do a temporary partial (ugh)
orig_start = self.kernel_manager.start_kernel
self.kernel_manager.start_kernel = partial(
self.kernel_manager.start_kernel, env=env, kernel_headers=kernel_headers
self.kernel_manager.start_kernel,
env=env,
kernel_headers=kernel_headers,
)
try:
await super().post()
Expand All @@ -119,8 +130,14 @@ async def get(self):
"""
if not self.settings.get("eg_list_kernels"):
raise tornado.web.HTTPError(403, "Forbidden")
else:
await super().get()

tenant_id_filter = self.request.query_arguments.get("tenant_id")
if isinstance(tenant_id_filter, list):
tenant_id_filter = tenant_id_filter[0].decode("utf-8")

km = self.kernel_manager
kernels = await ensure_async(km.list_kernels(tenant_id=tenant_id_filter))
self.finish(json.dumps(kernels, default=date_default))

def options(self, **kwargs):
"""Method for properly handling CORS pre-flight"""
54 changes: 41 additions & 13 deletions enterprise_gateway/services/kernels/remotemanager.py
Expand Up @@ -7,6 +7,7 @@
import signal
import time
import uuid
from typing import Optional

from jupyter_client.ioloop.manager import AsyncIOLoopKernelManager
from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
Expand All @@ -19,6 +20,8 @@
from ..processproxies.processproxy import LocalProcessProxy, RemoteProcessProxy
from ..sessions.kernelsessionmanager import KernelSessionManager

UNIVERSAL_TENANT_ID: str = "27182818-2845-9045-2353-602874713527"

default_kernel_launch_timeout = float(os.getenv("EG_KERNEL_LAUNCH_TIMEOUT", "30"))
kernel_restart_status_poll_interval = float(os.getenv("EG_RESTART_STATUS_POLL_INTERVAL", 1.0))

Expand Down Expand Up @@ -122,14 +125,19 @@ def new_kernel_id(**kwargs):

class TrackPendingRequests:
"""
Simple class to track (increment/decrement) pending kernel start requests, both total and per user.
Simple class to track (increment/decrement) pending kernel start requests, both total and per user.

This tracking is necessary due to an inherent race condition that occurs now that kernel startup is
asynchronous. As a result, multiple/simultaneous requests must be considered, in addition to all existing
kernel sessions.

This class instance is essentially a singleton in that its single instance is contained within
the single instance of RemoteMappingKernelManager.
"""

_pending_requests_all = 0
_pending_requests_user = {}
def __init__(self):
self._pending_requests_all = 0
self._pending_requests_user = {}

def increment(self, username: str) -> None:
self._pending_requests_all += 1
Expand Down Expand Up @@ -181,7 +189,6 @@ async def start_kernel(self, *args, **kwargs):
kernel_name=kwargs["kernel_name"], username=username
)
)

# Check max kernel limits
self._enforce_kernel_limits(username)

Expand Down Expand Up @@ -243,7 +250,6 @@ def _enforce_kernel_limits(self, username: str) -> None:
pending_all, pending_user = RemoteMappingKernelManager.pending_requests.get_counts(
username
)

# Enforce overall limit...
if self.parent.max_kernels is not None:
active_and_pending = len(self.list_kernels()) + pending_all
Expand Down Expand Up @@ -286,8 +292,22 @@ def remove_kernel(self, kernel_id):
super().remove_kernel(kernel_id)
self.parent.kernel_session_manager.delete_session(kernel_id)

def list_kernels(self, tenant_id: Optional[str] = None):
"""Returns a list of kernel models for the running kernels, filtering on tenant_id if provided."""
kernels = []
for kernel_id, kernel_manager in self._kernels.items():
try:
# If requested, discard non-matching tenant_id kernels
if tenant_id and tenant_id != kernel_manager.tenant_id:
continue
model = self.kernel_model(kernel_id)
kernels.append(model)
except (web.HTTPError, KeyError):
pass # Probably due to a (now) non-existent kernel, continue building the list
return kernels

def start_kernel_from_session(
self, kernel_id, kernel_name, connection_info, process_info, launch_args
self, kernel_id, kernel_name, connection_info, process_info, launch_args, tenant_id
):
"""
Starts a kernel from a persisted kernel session.
Expand All @@ -311,6 +331,8 @@ def start_kernel_from_session(
from persistent storage
launch_args : dict
The arguments used for the initial launch of the kernel
tenant_id : str
The tenant_id corresponding to this kernel
Returns
-------
True if kernel could be located and started, False otherwise.
Expand All @@ -329,6 +351,8 @@ def start_kernel_from_session(
kernel_name=kernel_name,
**constructor_kwargs,
)
# Set the persisted tenant_id
km.tenant_id = tenant_id

# Load connection info into member vars - no need to write out connection file
km.load_connection_info(connection_info)
Expand Down Expand Up @@ -385,13 +409,14 @@ def __init__(self, **kwargs):
self.public_key = None
self.sigint_value = None
self.kernel_id = None
self.tenant_id = None
self.user_overrides = {}
self.kernel_launch_timeout = default_kernel_launch_timeout
self.restarting = False # need to track whether we're in a restart situation or not

# If this instance supports port caching, then disable cache_ports since we don't need this
# for remote kernels and it breaks the ability to support port ranges for local kernels (which
# is viewed as more imporant for EG).
# for remote kernels, and it breaks the ability to support port ranges for local kernels (which
# is viewed as more important for EG).
# Note: This check MUST remain in this method since cache_ports is used immediately
# following construction.
if hasattr(self, "cache_ports"):
Expand Down Expand Up @@ -458,6 +483,9 @@ def _capture_user_overrides(self, **kwargs):
of the kernelspec env stanza that would have otherwise overridden the user-provided values.
"""
env = kwargs.get("env", {})
# Although this particular env is not "user-provided", we still handle its capture here.
self.tenant_id = env.get("KERNEL_TENANT_ID", UNIVERSAL_TENANT_ID)

# If KERNEL_LAUNCH_TIMEOUT is passed in the payload, override it.
self.kernel_launch_timeout = float(
env.get("KERNEL_LAUNCH_TIMEOUT", default_kernel_launch_timeout)
Expand Down Expand Up @@ -489,7 +517,7 @@ def format_kernel_cmd(self, extra_arguments=None):
if self.kernel_id:
ns["kernel_id"] = self.kernel_id

pat = re.compile(r"\{([A-Za-z0-9_]+)\}")
pat = re.compile(r"{([A-Za-z0-9_]+)}")

def from_ns(match):
"""Get the key out of ns if it's there, otherwise no change."""
Expand Down Expand Up @@ -532,7 +560,7 @@ def request_shutdown(self, restart=False):
super().request_shutdown(restart)

# If we're using a remote proxy, we need to send the launcher indication that we're
# shutting down so it can exit its listener thread, if its using one.
# shutting down so it can exit its listener thread, if it's using one.
if isinstance(self.process_proxy, RemoteProcessProxy):
self.process_proxy.shutdown_listener()

Expand Down Expand Up @@ -635,7 +663,7 @@ def cleanup(self, connection_file=True):
# remains here for pre-6.2.0 jupyter_client installations.

# Note we must use `process_proxy` here rather than `kernel`, although they're the same value.
# The reason is because if the kernel shutdown sequence has triggered its "forced kill" logic
# The reason is that if the kernel shutdown sequence has triggered its "forced kill" logic
# then that method (jupyter_client/manager.py/_kill_kernel()) will set `self.kernel` to None,
# which then prevents process proxy cleanup.
if self.process_proxy:
Expand All @@ -652,7 +680,7 @@ def cleanup_resources(self, restart=False):
# will not be called until jupyter_client 6.2.0 has been released.

# Note we must use `process_proxy` here rather than `kernel`, although they're the same value.
# The reason is because if the kernel shutdown sequence has triggered its "forced kill" logic
# The reason is that if the kernel shutdown sequence has triggered its "forced kill" logic
# then that method (jupyter_client/manager.py/_kill_kernel()) will set `self.kernel` to None,
# which then prevents process proxy cleanup.
if self.process_proxy:
Expand Down Expand Up @@ -685,7 +713,7 @@ def write_connection_file(self):

def _get_process_proxy(self):
"""
Reads the associated kernelspec and to see if has a process proxy stanza.
Reads the associated kernelspec to see if it has a process proxy stanza.
If one exists, it instantiates an instance. If a process proxy is not
specified in the kernelspec, a LocalProcessProxy stanza is fabricated and
instantiated.
2 changes: 2 additions & 0 deletions enterprise_gateway/services/sessions/kernelsessionmanager.py
Expand Up @@ -95,6 +95,7 @@ def create_session(self, kernel_id, **kwargs):
# Compose the kernel_session entry
kernel_session = dict()
kernel_session["kernel_id"] = kernel_id
kernel_session["tenant_id"] = km.tenant_id
kernel_session["username"] = KernelSessionManager.get_kernel_username(**kwargs)
kernel_session["kernel_name"] = km.kernel_name

Expand Down Expand Up @@ -183,6 +184,7 @@ def _start_session(self, kernel_session):
connection_info=kernel_session["connection_info"],
process_info=kernel_session["process_info"],
launch_args=kernel_session["launch_args"],
tenant_id=kernel_session["tenant_id"],
)
if not kernel_started:
return False