Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sessions #48

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
sudo: false
language: python
python:
- "2.7"
Expand All @@ -8,4 +7,4 @@ python:
- "3.8"
install: "pip install ."
script: "nosetests"
dist: bionic
dist: bionic
5 changes: 3 additions & 2 deletions sickle/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(self, endpoint,
self.class_mapping = class_mapping or DEFAULT_CLASS_MAP
self.encoding = encoding
self.request_args = request_args
self.session = requests.Session()

def harvest(self, **kwargs): # pragma: no cover
"""Make HTTP requests to the OAI server.
Expand All @@ -134,8 +135,8 @@ def harvest(self, **kwargs): # pragma: no cover

def _request(self, kwargs):
if self.http_method == 'GET':
return requests.get(self.endpoint, params=kwargs, **self.request_args)
return requests.post(self.endpoint, data=kwargs, **self.request_args)
return self.session.get(self.endpoint, params=kwargs, **self.request_args)
return self.session.post(self.endpoint, data=kwargs, **self.request_args)

def ListRecords(self, ignore_deleted=False, **kwargs):
"""Issue a ListRecords request.
Expand Down
3 changes: 2 additions & 1 deletion sickle/tests/test_harvesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from lxml import etree
from nose.tools import raises
from requests import Session
import mock

from sickle import Sickle
Expand Down Expand Up @@ -238,7 +239,7 @@ class TestCaseWrongEncoding(unittest.TestCase):

def __init__(self, methodName='runTest'):
super(TestCaseWrongEncoding, self).__init__(methodName)
self.patch = mock.patch('sickle.app.requests.get', mock_get)
self.patch = mock.patch.object(Session, 'get', mock_get)

def setUp(self):
self.patch.start()
Expand Down
11 changes: 6 additions & 5 deletions sickle/tests/test_sickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mock import patch, Mock
from nose.tools import raises
from requests import HTTPError
from requests import Session

from sickle import Sickle

Expand All @@ -33,7 +34,7 @@ def test_invalid_iterator(self):
def test_pass_request_args(self):
mock_response = Mock(text=u'<xml/>', content='<xml/>', status_code=200)
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
with patch.object(Session, 'get', mock_get):
sickle = Sickle('url', timeout=10, proxies=dict(),
auth=('user', 'password'))
sickle.ListRecords()
Expand All @@ -45,7 +46,7 @@ def test_pass_request_args(self):
def test_override_encoding(self):
mock_response = Mock(text='<xml/>', content='<xml/>', status_code=200)
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
with patch.object(Session, 'get', mock_get):
sickle = Sickle('url', encoding='encoding')
sickle.ListSets()
mock_get.assert_called_once_with('url',
Expand All @@ -56,7 +57,7 @@ def test_no_retry(self):
headers={'retry-after': '10'},
raise_for_status=Mock(side_effect=HTTPError))
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
with patch.object(Session, 'get', mock_get):
sickle = Sickle('url')
try:
sickle.ListRecords()
Expand All @@ -71,7 +72,7 @@ def test_retry_on_503(self):
mock_get = Mock(return_value=mock_response)
sleep_mock = Mock()
with patch('time.sleep', sleep_mock):
with patch('sickle.app.requests.get', mock_get):
with patch.object(Session, 'get', mock_get):
sickle = Sickle('url', max_retries=3, default_retry_after=0)
try:
sickle.ListRecords()
Expand All @@ -87,7 +88,7 @@ def test_retry_on_custom_code(self):
mock_response = Mock(status_code=500,
raise_for_status=Mock(side_effect=HTTPError))
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
with patch.object(Session, 'get', mock_get):
sickle = Sickle('url', max_retries=3, default_retry_after=0, retry_status_codes=(503, 500))
try:
sickle.ListRecords()
Expand Down