diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9011e6a..8c8db8c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,10 +2,10 @@ name: Test on: push: - branches: + branches: - master pull_request: - branches: + branches: - master types: [opened, synchronize, reopened] @@ -20,17 +20,18 @@ jobs: python-version: '3.10' cache: 'pip' - name: Install deps - run: pip install -r requirements.txt + run: pip install -r requirements.txt - name: Run tests run: python -m coverage run -m pytest env: CLIENT_ID: ${{ secrets.CLIENT_ID }} USERNAME: ${{ secrets.USERNAME }} - USERNAME2: ${{ secrets.USERNAME2 }} PASSWORD: ${{ secrets.PASSWORD }} OTPCODE: ${{ secrets.OTPCODE }} USER_POOL_ID: ${{ secrets.USER_POOL_ID }} - FEDERATED_POOL_ID: ${{ secrets.FEDERATED_POOL_ID }} + IDENTITY_POOL_ID: ${{ secrets.IDENTITY_POOL_ID }} + API_URL: ${{ secrets.API_URL }} + REGION: ${{ secrets.REGION }} - name: Publish code coverage uses: paambaati/codeclimate-action@v5.0.0 env: diff --git a/requirements.txt b/requirements.txt index 9d7d544..abd0ea5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ boto3==1.29.3 botocore==1.32.3 certifi==2023.11.17 charset-normalizer==3.3.2 +colorama==0.4.6 coverage==7.3.2 docutils==0.20.1 idna==3.4 @@ -18,20 +19,22 @@ nh3==0.2.14 packaging==23.2 pkginfo==1.9.6 pluggy==1.3.0 -Pygments==2.17.0 +Pygments==2.17.1 pytest==7.4.3 pytest-cov==4.1.0 python-dateutil==2.8.2 +python-dotenv==1.0.0 +pywin32-ctypes==0.2.2 readme-renderer==42.0 requests==2.31.0 requests-toolbelt==1.0.0 rfc3986==2.0.0 rich==13.7.0 s3transfer==0.7.0 -scribeauth==1.0.8 +scribeauth==1.1.1 six==1.16.0 twine==4.0.2 typing==3.7.4.3 typing_extensions==4.8.0 -urllib3==2.0.7 +urllib3>=2.0.7,<=2.1.0 zipp==3.17.0 diff --git a/scribemi/ScribeMi.py b/scribemi/ScribeMi.py index 9216a86..fd7cc58 100644 --- a/scribemi/ScribeMi.py +++ b/scribemi/ScribeMi.py @@ -1,10 +1,12 @@ import requests import json from io import BytesIO -from typing import 
class Env(TypedDict):
    """Connection settings consumed by ``MI``."""

    # Host, optionally followed by a base path, WITHOUT a scheme: ``MI``
    # prepends "https://" and signs requests for the part before the first "/".
    API_URL: str
    IDENTITY_POOL_ID: str
    USER_POOL_ID: str
    CLIENT_ID: str
    REGION: str


class MITaskBase(TypedDict):
    """Keys declared as always present on a task record."""

    jobid: str
    client: str
    status: str
    submitted: int


class MITask(MITaskBase, total=False):
    """Task record; the keys declared here may be absent (``total=False``)."""

    companyName: Optional[str]
    clientFilename: Optional[str]
    originalFilename: Optional[str]
    clientModelFilename: Optional[str]
    # Read by ``MI.fetch_model``; a missing or None value means there is
    # nothing to download yet.
    modelUrl: Optional[str]


class SubmitTaskParamsBase(TypedDict):
    """Keys that must be supplied to ``MI.submit_task``."""

    filetype: str


class SubmitTaskParams(SubmitTaskParamsBase, total=False):
    """Parameters for ``MI.submit_task``; the keys declared here are optional."""

    filename: Optional[str]
    companyname: Optional[str]


class UnauthenticatedException(Exception):
    """
    Exception raised when a request cannot be authenticated:
    - ``authenticate`` has not been called yet
    - The auth client returned no id token / refresh token
    - The API answered 401 or 403
    """


class TaskNotFoundException(Exception):
    """
    Exception raised when trying to access a nonexistent task
    (the API answered 404).
    """


class InvalidFiletypeException(Exception):
    """
    Exception raised when trying to upload a file with a wrong filetype.

    Accepted values: 'pdf', 'xlsx', 'xls', 'xlsm', 'doc', 'docx', 'ppt', 'pptx'
    """


class MI:
    """Client for the MI tasks API, authenticated through ``ScribeAuth``."""

    def __init__(self, env):
        """
        Store the settings and build the ScribeAuth client.

        Arguments:
        env -- Mapping with the keys declared in ``Env``.
        """
        self.env = env
        self.auth_client = ScribeAuth(
            {
                "client_id": env["CLIENT_ID"],
                "user_pool_id": env["USER_POOL_ID"],
                "identity_pool_id": env["IDENTITY_POOL_ID"],
            }
        )
        self.tokens = None
        self.user_id = None
        # Initialised here so the attribute exists before ``authenticate``.
        self.credentials = None
        self.request_auth = None

    def _refresh_request_auth(self, id_token):
        """Fetch federated credentials for ``id_token`` and rebuild the request signer."""
        self.credentials = self.auth_client.get_federated_credentials(
            self.user_id, id_token
        )
        # Only the part of API_URL before the first "/" is used as the signing host.
        host = self.env["API_URL"].split("/")[0]
        self.request_auth = AWSRequestsAuth(
            aws_access_key=self.credentials["AccessKeyId"],
            aws_secret_access_key=self.credentials["SecretKey"],
            aws_token=self.credentials["SessionToken"],
            aws_host=host,
            aws_region=self.env["REGION"],
            aws_service="execute-api",
        )

    def authenticate(self, param):
        """
        Obtain tokens and credentials so that requests can be signed.

        Arguments:
        param -- Keyword arguments forwarded to ``ScribeAuth.get_tokens``
                 (e.g. username/password, or refresh_token).

        Raises:
        UnauthenticatedException -- the returned tokens carry no id token.
        """
        self.tokens = self.auth_client.get_tokens(**param)
        id_token = self.tokens.get("id_token")
        if id_token is None:
            raise UnauthenticatedException("Authentication failed")
        self.user_id = self.auth_client.get_federated_id(id_token)
        self._refresh_request_auth(id_token)

    def reauthenticate(self):
        """
        Renew tokens and credentials using the stored refresh token.

        Raises:
        UnauthenticatedException -- ``authenticate`` was never completed, or
                                    the refresh/id token is missing.
        """
        if self.tokens is None or self.user_id is None:
            raise UnauthenticatedException("Must authenticate before reauthenticating")
        refresh_token = self.tokens.get("refresh_token")
        if refresh_token is None:
            raise UnauthenticatedException("Authentication failed")
        self.tokens = self.auth_client.get_tokens(refresh_token=refresh_token)
        id_token = self.tokens.get("id_token")
        if id_token is None:
            raise UnauthenticatedException("Authentication failed")
        self._refresh_request_auth(id_token)

    def call_endpoint(self, method, path, data=None, params=None):
        """
        Send a signed HTTPS request to the API and return the decoded JSON body.

        Arguments:
        method -- HTTP method name.
        path -- Path appended verbatim to ``env['API_URL']``.
        data -- Optional JSON body.
        params -- Optional query-string parameters.

        Raises:
        UnauthenticatedException -- not authenticated, or the API answered 401/403.
        TaskNotFoundException -- the API answered 404.
        Exception -- any other non-200 status.
        """
        if self.request_auth is None:
            raise UnauthenticatedException("Not authenticated")
        expiration = self.credentials["Expiration"]
        # Compare in the expiration's own timezone so naive and aware
        # datetimes are never mixed.
        if expiration < datetime.now(expiration.tzinfo):
            self.reauthenticate()
        res = requests.request(
            method=method,
            url="https://{host}{path}".format(host=self.env["API_URL"], path=path),
            params=params,
            json=data,
            auth=self.request_auth,
        )
        if res.status_code == 200:
            return json.loads(res.text)
        if res.status_code in (401, 403):
            raise UnauthenticatedException(
                "Authentication failed ({})".format(res.status_code)
            )
        if res.status_code == 404:
            raise TaskNotFoundException("Not found")
        raise Exception("Unexpected error ({})".format(res.status_code))

    def list_tasks(self, companyName=None) -> list[MITask]:
        """
        Return the ``tasks`` entry of the ``GET /tasks`` response.

        Arguments:
        companyName -- When given, sent as the ``company`` query parameter.
        """
        params = {"includePresigned": True}
        if companyName is not None:
            params["company"] = companyName
        return self.call_endpoint("GET", "/tasks", params=params).get("tasks")

    def get_task(self, jobid: str) -> MITask:
        """Return the decoded response of ``GET /tasks/<jobid>``."""
        return self.call_endpoint("GET", "/tasks/{}".format(jobid))

    def fetch_model(self, task: MITask):
        """
        Download and decode the JSON document at ``task['modelUrl']``.

        Raises:
        Exception -- the task has no ``modelUrl``, or the download returned
                     an unexpected status.
        UnauthenticatedException -- the download answered 401/403.
        """
        model_url = task.get("modelUrl")
        if model_url is None:
            raise Exception(
                "Cannot load model for task {}: model is not ready for export".format(
                    task.get("jobid")
                )
            )
        res = requests.get(model_url)
        if res.status_code == 200:
            return json.loads(res.text)
        if res.status_code in (401, 403):
            raise UnauthenticatedException(
                "{} Authentication failed (possibly due to timeout: try calling get_task immediately before fetch_model)".format(
                    res.status_code
                )
            )
        raise Exception("Unexpected error ({})".format(res.status_code))

    def consolidate_tasks(self, tasks: list[MITask]):
        """
        Request ``/fundportfolio`` for the given tasks and return its ``model`` entry.

        Arguments:
        tasks -- Task records; only their ``jobid`` is used.
        """
        jobids_param = ";".join(task["jobid"] for task in tasks)
        res = self.call_endpoint("GET", "/fundportfolio?jobids={}".format(jobids_param))
        # call_endpoint returns a plain dict (json.loads), so the model has to
        # be read by key; attribute access (res.model) raises AttributeError.
        return res.get("model")

    def submit_task(
        self, file_or_filename: Union[str, BytesIO, BinaryIO], params: SubmitTaskParams
    ):
        """
        Create a task, upload its file, and return the new job id.

        Arguments:
        file_or_filename -- Path of the file to upload, or an open binary
                            file-like object.
        params -- Task parameters; ``filetype`` is required. The caller's
                  dict is not modified.

        Raises:
        InvalidFiletypeException -- ``params['filetype']`` is not accepted.
        """
        filetype_list = ("pdf", "xlsx", "xls", "xlsm", "doc", "docx", "ppt", "pptx")
        if params.get("filetype") not in filetype_list:
            raise InvalidFiletypeException(
                "Invalid filetype. Accepted values: 'pdf', 'xlsx', 'xls', 'xlsm', 'doc', 'docx', 'ppt', 'pptx'."
            )

        # Work on a copy: defaulting the filename must not leak into the
        # dict owned by the caller.
        payload = dict(params)
        if isinstance(file_or_filename, str) and payload.get("filename") is None:
            payload["filename"] = file_or_filename

        post_res = self.call_endpoint("POST", "/tasks", payload)
        put_url = post_res["url"]
        if isinstance(file_or_filename, str):
            with open(file_or_filename, "rb") as file:
                upload_file(file, put_url)
        else:
            # upload_file returns nothing; returning its result here would
            # drop the job id for file-like inputs.
            upload_file(file_or_filename, put_url)
        return post_res["jobid"]

    def delete_task(self, task: MITask):
        """Issue ``DELETE /tasks/<jobid>`` for ``task`` and return the decoded response."""
        return self.call_endpoint("DELETE", "/tasks/{}".format(task["jobid"]))


def upload_file(file, url):
    """
    PUT ``file`` to ``url``.

    Raises:
    Exception -- the server answered with a status other than 200.
    """
    res = requests.put(url, data=file)
    if res.status_code != 200:
        raise Exception("Error uploading file: {}".format(res.status_code))
os.environ['PASSWORD'], + "username": os.environ["USERNAME"], + "password": os.environ["PASSWORD"], } + class TestScribeMiAuth(unittest.TestCase): def test_fetch_credentials_with_username_and_password(self): client = MI(env) @@ -33,7 +34,15 @@ def test_fetch_credentials_with_username_and_password(self): def test_fetch_credentials_with_refresh_token(self): client = MI(env) client.authenticate(username_and_password) - client.authenticate({ 'refresh_token': client.tokens['refresh_token'] }) + tokens = client.tokens + if tokens != None: + refresh_token = tokens.get("refresh_token") + if refresh_token != None: + client.authenticate({"refresh_token": refresh_token}) + else: + assert False + else: + assert False def test_reauthenticate(self): client = MI(env) @@ -53,27 +62,33 @@ def test_throws_if_not_authenticated(self): except UnauthenticatedException: assert True + class TestScribeMiErrorHandling(unittest.TestCase): def test_throws_on_error_response(self): client = MI(env) client.authenticate(username_and_password) try: - client.get_task('invalidJobid') + client.get_task("invalidJobid") assert False except: assert True + class TestScribeMiEndpoints(unittest.TestCase): def test_endpoints(self): client = MI(env) client.authenticate(username_and_password) - jobid = client.submit_task('tests/companies_house_document.pdf', { - 'filetype': 'pdf', - }) + jobid = client.submit_task( + "tests/companies_house_document.pdf", + {"filetype": "pdf"}, + ) client.list_tasks() + if jobid == None: + assert False + task = client.get_task(jobid) try: