Skip to content

Commit

Permalink
Merge pull request #34 from IT-CEREBRUM/use-http-utils-in-dfo-client
Browse files Browse the repository at this point in the history
Refactor the dfo client to use `Cerebrum.utils.http`
  • Loading branch information
fredrikhl authored and GitHub Enterprise committed Jan 12, 2023
2 parents 29ade88 + 723c400 commit 2ee3d6d
Showing 1 changed file with 19 additions and 67 deletions.
86 changes: 19 additions & 67 deletions Cerebrum/modules/no/dfo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,67 +31,18 @@

import requests
import six
from six.moves.urllib.parse import (
quote_plus,
urlparse,
urljoin as _urljoin,
)

from Cerebrum.config import loader
from Cerebrum.config.configuration import (Configuration,
ConfigDescriptor,
Namespace)
from Cerebrum.config.secrets import Secret, get_secret_from_string
from Cerebrum.config.settings import Boolean, Iterable, String
from Cerebrum.utils import http as http_utils

logger = logging.getLogger(__name__)


def quote_path_arg(arg):
return quote_plus(str(arg))


def merge_dicts(*dicts):
"""
Combine a series of dicts without mutating any of them.
>>> merge_dicts({'a': 1}, {'b': 2})
{'a': 1, 'b': 2}
>>> merge_dicts({'a': 1}, {'a': 2})
{'a': 2}
>>> merge_dicts(None, None, None)
{}
"""
combined = dict()
for d in dicts:
if not d:
continue
for k in d:
combined[k] = d[k]
return combined


def urljoin(base_url, *paths):
"""
A sane urljoin.
Note how urllib.parse.urljoin will assume 'relative to parent' when the
base_url doesn't end with a '/':
>>> urllib.parse.urljoin('https://localhost/foo', 'bar')
'https://localhost/bar'
>>> urljoin('https://localhost/foo', 'bar')
'https://localhost/foo/bar'
>>> urljoin('https://localhost/foo', 'bar', 'baz')
'https://localhost/foo/bar/baz'
"""
for path in paths:
base_url = _urljoin(base_url.rstrip('/') + '/', path)
return base_url


class SapEndpoints(object):
"""Get endpoints relative to the SAP API URL."""

Expand Down Expand Up @@ -119,16 +70,22 @@ def __repr__(self):
).format(cls=type(self), obj=self)

def get_employee(self, employee_id):
return urljoin(self.baseurl, self.employee_path,
quote_path_arg(employee_id))
return http_utils.urljoin(
self.baseurl,
self.employee_path,
http_utils.safe_path(employee_id))

def get_orgenhet(self, org_id):
return urljoin(self.baseurl, self.orgenhet_path,
quote_path_arg(org_id))
return http_utils.urljoin(
self.baseurl,
self.orgenhet_path,
http_utils.safe_path(org_id))

def get_stilling(self, stilling_id):
return urljoin(self.baseurl, self.stilling_path,
quote_path_arg(stilling_id))
return http_utils.urljoin(
self.baseurl,
self.stilling_path,
http_utils.safe_path(stilling_id))


class SapClient(object):
Expand Down Expand Up @@ -163,7 +120,7 @@ def __init__(self,
orgenhet_path=orgenhet_path,
stilling_path=stilling_path,
)
self.headers = merge_dicts(self.default_headers, headers)
self.headers = http_utils.merge_headers(self.default_headers, headers)
self.api_headers = {
'employee': employee_headers,
'orgenhet': orgenhet_headers,
Expand All @@ -190,13 +147,8 @@ def call(self,
params=None,
return_response=True,
**kwargs):
headers = merge_dicts(self.headers, headers)
if params is None:
params = {}
logger.debug('Calling %s %s with params=%r',
method_name,
urlparse(url).path,
params)
headers = http_utils.merge_headers(self.headers, headers)
params = params or {}
r = self.session.request(method_name,
url,
headers=headers,
Expand All @@ -220,7 +172,7 @@ def put(self, url, **kwargs):
# def get_employee(self, employee_id: str) -> [None, dict]:
def get_employee(self, employee_id):
url = self.urls.get_employee(employee_id)
headers = merge_dicts(self.headers, self.api_headers['employee'])
headers = self.api_headers['employee']
response = self.get(url, headers=headers)
if not self._is_api_response(response):
response.raise_for_status()
Expand All @@ -234,7 +186,7 @@ def get_employee(self, employee_id):
# def get_orgenhet(self, org_id: str) -> [None, dict]:
def get_orgenhet(self, org_id):
url = self.urls.get_orgenhet(org_id)
headers = merge_dicts(self.headers, self.api_headers['organisasjonId'])
headers = self.api_headers['organisasjonId']
response = self.get(url, headers=headers)
if response.status_code == 404:
return None
Expand All @@ -246,7 +198,7 @@ def get_orgenhet(self, org_id):
# def get_stilling(self, stilling_id: str) -> [None, dict]:
def get_stilling(self, stilling_id):
url = self.urls.get_stilling(stilling_id)
headers = merge_dicts(self.headers, self.api_headers['stilling'])
headers = self.api_headers['stilling']
response = self.get(url, headers=headers)
if response.status_code == 404:
return None
Expand Down

0 comments on commit 2ee3d6d

Please sign in to comment.