diff --git a/bioblend/galaxy/__init__.py b/bioblend/galaxy/__init__.py index 937b23d28..112ef6000 100644 --- a/bioblend/galaxy/__init__.py +++ b/bioblend/galaxy/__init__.py @@ -3,6 +3,8 @@ """ from typing import Optional +import requests + from bioblend.galaxy import ( config, container_resolution, @@ -39,6 +41,7 @@ def __init__( email: Optional[str] = None, password: Optional[str] = None, verify: bool = True, + session: requests.sessions.Session() = None, ) -> None: """ A base representation of a connection to a Galaxy instance, identified @@ -106,6 +109,8 @@ def __init__( self.tool_data = tool_data.ToolDataClient(self) self.folders = folders.FoldersClient(self) self.tool_dependencies = tool_dependencies.ToolDependenciesClient(self) + if session is not None: + self.session = session def __repr__(self) -> str: """ diff --git a/bioblend/galaxy/objects/galaxy_instance.py b/bioblend/galaxy/objects/galaxy_instance.py index ddde94831..909eca02f 100644 --- a/bioblend/galaxy/objects/galaxy_instance.py +++ b/bioblend/galaxy/objects/galaxy_instance.py @@ -9,6 +9,8 @@ Optional, ) +import requests + import bioblend import bioblend.galaxy from bioblend.galaxy.datasets import TERMINAL_STATES @@ -57,6 +59,7 @@ def __init__( email: Optional[str] = None, password: Optional[str] = None, verify: bool = True, + session: requests.sessions.Session() = None, ) -> None: self.gi = bioblend.galaxy.GalaxyInstance(url, api_key, email, password, verify) self.log = bioblend.log @@ -68,6 +71,8 @@ def __init__( self.invocations = client.ObjInvocationClient(self) self.tools = client.ObjToolClient(self) self.jobs = client.ObjJobClient(self) + if session is not None: + self.session = session def _wait_datasets( self, datasets: Iterable[wrappers.Dataset], polling_interval: float, break_on_error: bool = True diff --git a/bioblend/galaxyclient.py b/bioblend/galaxyclient.py index 3a6c168e0..4bff70997 100644 --- a/bioblend/galaxyclient.py +++ b/bioblend/galaxyclient.py @@ -38,6 +38,7 @@ def __init__( password: Optional[str] = None, verify: bool = True, timeout: Optional[float] = None, + session: requests.sessions.Session() = None, ) -> None: """ :param verify: Whether to verify the server's TLS certificate @@ -54,11 +55,18 @@ def __init__( for scheme in ("https://", "http://"): log.warning(f"Missing scheme in url, trying with {scheme}") with contextlib.suppress(requests.RequestException): - r = requests.get( - scheme + url, - timeout=self.timeout, - verify=self.verify, - ) + if session is not None: + r = session.get( + scheme + url, + timeout=self.timeout, + verify=self.verify, + ) + else: + r = requests.get( + scheme + url, + timeout=self.timeout, + verify=self.verify, + ) r.raise_for_status() found_scheme = scheme break @@ -83,6 +91,8 @@ def __init__( self._max_get_attempts = 1 # Delay in seconds between subsequent retries. self._get_retry_delay = 10.0 + # Add session to the client if provided + self.session = session @property def max_get_attempts(self) -> int: @@ -133,7 +143,10 @@ def make_get_request(self, url: str, **kwargs: Any) -> requests.Response: headers = self.json_headers kwargs.setdefault("timeout", self.timeout) kwargs.setdefault("verify", self.verify) - r = requests.get(url, headers=headers, **kwargs) + if self.session is not None: + r = self.session.get(url, headers=headers, **kwargs) + else: + r = requests.get(url, headers=headers, **kwargs) return r def make_post_request( @@ -173,16 +186,26 @@ def my_dumps(d: dict) -> dict: data = json.dumps(payload) if payload is not None else None headers = self.json_headers post_params = params - - r = requests.post( - url, - params=post_params, - data=data, - headers=headers, - timeout=self.timeout, - allow_redirects=False, - verify=self.verify, - ) + if self.session is not None: + r = self.session.post( + url, + params=post_params, + data=data, + headers=headers, + timeout=self.timeout, + allow_redirects=False, + verify=self.verify, + ) + else: + r = requests.post( + url, + params=post_params, + data=data, + headers=headers, + timeout=self.timeout, + allow_redirects=False, + verify=self.verify, + ) if r.status_code == 200: try: return r.json() @@ -214,15 +237,26 @@ def make_delete_request( """ data = json.dumps(payload) if payload is not None else None headers = self.json_headers - r = requests.delete( - url, - params=params, - data=data, - headers=headers, - timeout=self.timeout, - allow_redirects=False, - verify=self.verify, - ) + if self.session is not None: + r = self.session.delete( + url, + params=params, + data=data, + headers=headers, + timeout=self.timeout, + allow_redirects=False, + verify=self.verify, + ) + else: + r = requests.delete( + url, + params=params, + data=data, + headers=headers, + timeout=self.timeout, + allow_redirects=False, + verify=self.verify, + ) return r def make_put_request(self, url: str, payload: Optional[dict] = None, params: Optional[dict] = None) -> Any: @@ -236,15 +270,26 @@ def make_put_request(self, url: str, payload: Optional[dict] = None, params: Opt """ data = json.dumps(payload) if payload is not None else None headers = self.json_headers - r = requests.put( - url, - params=params, - data=data, - headers=headers, - timeout=self.timeout, - allow_redirects=False, - verify=self.verify, - ) + if self.session is not None: + r = self.session.put( + url, + params=params, + data=data, + headers=headers, + timeout=self.timeout, + allow_redirects=False, + verify=self.verify, + ) + else: + r = requests.put( + url, + params=params, + data=data, + headers=headers, + timeout=self.timeout, + allow_redirects=False, + verify=self.verify, + ) if r.status_code == 200: try: return r.json() @@ -272,15 +317,26 @@ def make_patch_request(self, url: str, payload: Optional[dict] = None, params: O """ data = json.dumps(payload) if payload is not None else None headers = self.json_headers - r = requests.patch( - url, - params=params, - data=data, - headers=headers, - timeout=self.timeout, - allow_redirects=False, - verify=self.verify, - ) + if self.session is not None: + r = self.session.patch( + url, + params=params, + data=data, + headers=headers, + timeout=self.timeout, + allow_redirects=False, + verify=self.verify, + ) + else: + r = requests.patch( + url, + params=params, + data=data, + headers=headers, + timeout=self.timeout, + allow_redirects=False, + verify=self.verify, + ) if r.status_code == 200: try: return r.json() @@ -355,12 +411,20 @@ def key(self) -> Optional[str]: auth_url = f"{self.url}/authenticate/baseauth" # Use lower level method instead of make_get_request() because we # need the additional Authorization header. - r = requests.get( - auth_url, - headers=headers, - timeout=self.timeout, - verify=self.verify, - ) + if self.session is not None: + r = self.session.get( + auth_url, + headers=headers, + timeout=self.timeout, + verify=self.verify, + ) + else: + r = requests.get( + auth_url, + headers=headers, + timeout=self.timeout, + verify=self.verify, + ) if r.status_code != 200: raise Exception("Failed to authenticate user.") response = r.json() diff --git a/docs/examples/session_handling/cookie_handler.py b/docs/examples/session_handling/cookie_handler.py new file mode 100644 index 000000000..474e9b730 --- /dev/null +++ b/docs/examples/session_handling/cookie_handler.py @@ -0,0 +1,94 @@ +""" +Cookie handler for Authelia authentication using the Galaxy API +""" + +import getpass +import logging as log +import sys +from http.cookiejar import LWPCookieJar +from pathlib import Path +from pprint import pprint + +import requests +from galaxy_api import * + +AUTH_HOSTNAME = "auth.service.org" +API_HOSTNAME = "galaxy.service.org" +cookie_path = Path(".galaxy_auth.txt") +cookie_jar = LWPCookieJar(cookie_path) + + +class ExpiredCookies(Exception): + pass + + +class NoCookies(Exception): + pass + + +def main(): + try: + cookie_jar.load() # raises OSError + if not cookie_jar: # if empty due to expirations + raise ExpiredCookies() + except OSError: + print("No cached session found, please authenticate") + prompt_authentication() + except ExpiredCookies: + print("Session has expired, please authenticate") + prompt_authentication() + run_examples() + + +def prompt_authentication(): + # -------------------------------------------------------------------------- + # Prompt for username and password + + username = input("Please enter username: ") + password = getpass.getpass(f"Please enter password for {username}: ") + + # -------------------------------------------------------------------------- + # Prepare authentication packet and authenticate session using Authelia + + login_body = { + "username": username, + "password": password, + "requestMethod": "GET", + "keepMeLoggedIn": True, + "targetURL": API_HOSTNAME, + } + + with requests.sessions.Session() as session: + session.cookies = cookie_jar + session.verify = True + + auth = session.post(f"https://{AUTH_HOSTNAME}/api/firstfactor", json=login_body) + + response = session.get(f"https://{AUTH_HOSTNAME}/api/user/info") + if response.status_code != 200: + print("Authentication failed") + sys.exit() + else: + pprint(response.json()) + session.cookies.save() + + +def run_examples(): + GALAXY_KEY = "user_api_key" + WORKFLOW_NAME = "workflow_name" + with requests.sessions.Session() as session: + session.cookies = cookie_jar + + print("Running demo to demonstrate how to use the Galaxy API with Authelia") + + print("Getting workflows from Galaxy") + response = get_workflows(f"https://{API_HOSTNAME}", GALAXY_KEY, session=session) + print(response) + + print("Getting inputs for a workflow") + response = get_inputs(f"https://{API_HOSTNAME}", GALAXY_KEY, WORKFLOW_NAME, session=session) + print(response) + + +if __name__ == "__main__": + main() diff --git a/docs/examples/session_handling/galaxy_api.py b/docs/examples/session_handling/galaxy_api.py new file mode 100644 index 000000000..c052fc609 --- /dev/null +++ b/docs/examples/session_handling/galaxy_api.py @@ -0,0 +1,58 @@ +from bioblend.galaxy import GalaxyInstance + + +def get_inputs(server, api_key, workflow_name, session=None): + """ + Function to get an array of inputs for a given galaxy workflow + + Usage: + get_inputs( + server = "galaxy.server.org", + api_key = "user_api_key", + workflow_name = "workflow_name", + ) + + Args: + server (string): Galaxy server address + api_key (string): User generated string from galaxy instance + to create: User > Preferences > Manage API Key > Create a new key + workflow_name (string): Target workflow name + Returns: + inputs (array of strings): Input files expected by the workflow, these will be in the same order as they should be given in the main API call + """ + + gi = GalaxyInstance(url=server, key=api_key, session=session) + api_workflow = gi.workflows.get_workflows(name=workflow_name) + steps = gi.workflows.export_workflow_dict(api_workflow[0]["id"])["steps"] + inputs = [] + for step in steps: + # Some of the steps don't take inputs so have to skip these + if len(steps[step]["inputs"]) > 0: + inputs.append(steps[step]["inputs"][0]["name"]) + + return inputs + + +def get_workflows(server, api_key, session=None): + """ + Function to get an array of workflows available on a given galaxy instance + + Usage: + get_workflows( + server = "galaxy.server.org", + api_key = "user_api_key", + ) + + Args: + server (string): Galaxy server address + api_key (string): User generated string from galaxy instance + to create: User > Preferences > Manage API Key > Create a new key + Returns: + workflows (array of strings): Workflows available to be run on the galaxy instance provided + """ + gi = GalaxyInstance(url=server, key=api_key, session=session) + workflows_dict = gi.workflows.get_workflows() + workflows = [] + for item in workflows_dict: + workflows.append(item["name"]) + return workflows