diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 2338a5f..718db1e 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -13,103 +13,109 @@ AuthenticationError, ForbiddenError, InvalidRequestError, NotFoundError, ServiceUnavailableError) +# global session +session = None -class Connection: - @classmethod - def request(cls, http_verb, url, **options): - if 'headers' in options: - headers = options['headers'] + +def request(http_verb, url, **options): + if 'headers' in options: + headers = options['headers'] + else: + headers = {} + + accept_value = 'application/json' + if ApiConfig.api_version: + accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version + + headers = Util.merge_to_dicts({ + 'accept': accept_value, + 'request-source': 'python', + 'request-source-version': VERSION + }, headers) + if ApiConfig.api_key: + headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) + + options['headers'] = headers + + abs_url = '%s/%s' % (ApiConfig.api_base, url) + + return execute_request(http_verb, abs_url, **options) + + +def execute_request(http_verb, url, **options): + session = get_session() + + try: + response = session.request( + method=http_verb, + url=url, + verify=ApiConfig.verify_ssl, + **options + ) + if response.status_code < 200 or response.status_code >= 300: + handle_api_error(response) else: - headers = {} - - accept_value = 'application/json' - if ApiConfig.api_version: - accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version - - headers = Util.merge_to_dicts({'accept': accept_value, - 'request-source': 'python', - 'request-source-version': VERSION}, headers) - if ApiConfig.api_key: - headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) - - options['headers'] = headers - - abs_url = '%s/%s' % (ApiConfig.api_base, url) - - return cls.execute_request(http_verb, abs_url, **options) - - @classmethod - def execute_request(cls, http_verb, url, **options): - session = cls.get_session() - - try: - response = session.request(method=http_verb, - url=url, - verify=ApiConfig.verify_ssl, - **options) - if response.status_code < 200 or response.status_code >= 300: - cls.handle_api_error(response) - else: - return response - except requests.exceptions.RequestException as e: - if e.response: - cls.handle_api_error(e.response) - raise e - - @classmethod - def get_session(cls): + return response + except requests.exceptions.RequestException as e: + if e.response: + handle_api_error(e.response) + raise e + + +def get_retries(): + if not ApiConfig.use_retries: + return Retry(total=0) + + Retry.BACKOFF_MAX = ApiConfig.max_wait_between_retries + retries = Retry(total=ApiConfig.number_of_retries, + connect=ApiConfig.number_of_retries, + read=ApiConfig.number_of_retries, + status_forcelist=ApiConfig.retry_status_codes, + backoff_factor=ApiConfig.retry_backoff_factor, + raise_on_status=False) + + return retries + + +def get_session(): + global session + if session is None: session = requests.Session() - adapter = HTTPAdapter(max_retries=cls.get_retries()) + adapter = HTTPAdapter(max_retries=get_retries()) session.mount(ApiConfig.api_protocol, adapter) - - return session - - @classmethod - def get_retries(cls): - if not ApiConfig.use_retries: - return Retry(total=0) - - Retry.BACKOFF_MAX = ApiConfig.max_wait_between_retries - retries = Retry(total=ApiConfig.number_of_retries, - connect=ApiConfig.number_of_retries, - read=ApiConfig.number_of_retries, - status_forcelist=ApiConfig.retry_status_codes, - backoff_factor=ApiConfig.retry_backoff_factor, - raise_on_status=False) - - return retries - - @classmethod - def parse(cls, response): - try: - return response.json() - except ValueError: - raise DataLinkError(http_status=response.status_code, http_body=response.text) - - @classmethod - def handle_api_error(cls, resp): - error_body = cls.parse(resp) - - # if our app does not form a proper data_link_error response - # throw generic error - if 'error' not in error_body: - raise DataLinkError(http_status=resp.status_code, http_body=resp.text) - - code = error_body['error']['code'] - message = error_body['error']['message'] - prog = re.compile('^QE([a-zA-Z])x') - if prog.match(code): - code_letter = prog.match(code).group(1) - - d_klass = { - 'L': LimitExceededError, - 'M': InternalServerError, - 'A': AuthenticationError, - 'P': ForbiddenError, - 'S': InvalidRequestError, - 'C': NotFoundError, - 'X': ServiceUnavailableError - } - klass = d_klass.get(code_letter, DataLinkError) - - raise klass(message, resp.status_code, resp.text, resp.headers, code) + return session + + +def parse(response): + try: + return response.json() + except ValueError: + raise DataLinkError(http_status=response.status_code, http_body=response.text) + + +def handle_api_error(resp): + error_body = parse(resp) + + # if our app does not form a proper data_link_error response + # throw generic error + if 'error' not in error_body: + raise DataLinkError(http_status=resp.status_code, http_body=resp.text) + + code = error_body['error']['code'] + message = error_body['error']['message'] + prog = re.compile('^QE([a-zA-Z])x') + if prog.match(code): + code_letter = prog.match(code).group(1) + + d_klass = { + 'L': LimitExceededError, + 'M': InternalServerError, + 'A': AuthenticationError, + 'P': ForbiddenError, + 'S': InvalidRequestError, + 'C': NotFoundError, + 'X': ServiceUnavailableError + } + klass = d_klass.get(code_letter, DataLinkError) + + raise klass(message, resp.status_code, resp.text, resp.headers, code) diff --git a/nasdaqdatalink/model/database.py b/nasdaqdatalink/model/database.py index 870dedc..fbf9e73 100644 --- a/nasdaqdatalink/model/database.py +++ b/nasdaqdatalink/model/database.py @@ -4,7 +4,7 @@ import nasdaqdatalink.model.dataset from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation @@ -43,7 +43,7 @@ def bulk_download_to_file(self, file_or_folder_path, **options): path_url = self._bulk_download_path() options['stream'] = True - r = Connection.request('get', path_url, **options) + r = connection.request('get', path_url, **options) file_path = file_or_folder_path if os.path.isdir(file_or_folder_path): file_path = file_or_folder_path + '/' + os.path.basename(urlparse(r.url).path) diff --git a/nasdaqdatalink/model/datatable.py b/nasdaqdatalink/model/datatable.py index 2edadb8..935590e 100644 --- a/nasdaqdatalink/model/datatable.py +++ b/nasdaqdatalink/model/datatable.py @@ -3,7 +3,7 @@ from six.moves.urllib.request import urlopen -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation @@ -51,7 +51,7 @@ def _request_file_info(self, file_or_folder_path, **options): updated_options = Util.convert_options(request_type=request_type, **options) - r = Connection.request(request_type, url, **updated_options) + r = connection.request(request_type, url, **updated_options) response_data = r.json() diff --git a/nasdaqdatalink/operations/get.py b/nasdaqdatalink/operations/get.py index 8f93b95..efe3a26 100644 --- a/nasdaqdatalink/operations/get.py +++ b/nasdaqdatalink/operations/get.py @@ -1,7 +1,7 @@ from inflection import singularize from .operation import Operation -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.util import Util @@ -21,7 +21,7 @@ def __get_raw_data__(self): path = Util.constructed_path(cls.get_path(), options['params']) - r = Connection.request('get', path, **options) + r = connection.request('get', path, **options) response_data = r.json() Util.convert_to_dates(response_data) self._raw_data = response_data[singularize(cls.lookup_key())] diff --git a/nasdaqdatalink/operations/list.py b/nasdaqdatalink/operations/list.py index 6aa020a..6e94e78 100644 --- a/nasdaqdatalink/operations/list.py +++ b/nasdaqdatalink/operations/list.py @@ -1,5 +1,5 @@ from .operation import Operation -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.util import Util from nasdaqdatalink.model.paginated_list import PaginatedList from nasdaqdatalink.utils.request_type_util import RequestType @@ -12,7 +12,7 @@ def all(cls, **options): if 'params' not in options: options['params'] = {} path = Util.constructed_path(cls.list_path(), options['params']) - r = Connection.request('get', path, **options) + r = connection.request('get', path, **options) response_data = r.json() Util.convert_to_dates(response_data) resource = cls.create_list_from_response(response_data) @@ -27,7 +27,7 @@ def page(cls, datatable, **options): updated_options = Util.convert_options(request_type=request_type, **options) - r = Connection.request(request_type, path, **updated_options) + r = connection.request(request_type, path, **updated_options) response_data = r.json() Util.convert_to_dates(response_data) diff --git a/test/test_connection.py b/test/test_connection.py index 96d8380..384d1e0 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -1,4 +1,4 @@ -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.api_config import ApiConfig from nasdaqdatalink.errors.data_link_error import ( DataLinkError, LimitExceededError, InternalServerError, @@ -42,7 +42,7 @@ def test_nasdaqdatalink_exceptions_no_retries(self, request_method): for expected_error in data_link_errors: self.assertRaises( - expected_error[2], lambda: Connection.request(request_method, 'databases')) + expected_error[2], lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) def test_parse_error(self, request_method): @@ -51,7 +51,7 @@ def test_parse_error(self, request_method): "https://data.nasdaq.com/api/v3/databases", body="not json", status=500) self.assertRaises( - DataLinkError, lambda: Connection.request(request_method, 'databases')) + DataLinkError, lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) def test_non_data_link_error(self, request_method): @@ -62,16 +62,16 @@ def test_non_data_link_error(self, request_method): {'foobar': {'code': 'blah', 'message': 'something went wrong'}}), status=500) self.assertRaises( - DataLinkError, lambda: Connection.request(request_method, 'databases')) + DataLinkError, lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) - @patch('nasdaqdatalink.connection.Connection.execute_request') + @patch('nasdaqdatalink.connection.execute_request') def test_build_request(self, request_method, mock): ApiConfig.api_key = 'api_token' ApiConfig.api_version = '2015-04-09' params = {'per_page': 10, 'page': 2} headers = {'x-custom-header': 'header value'} - Connection.request(request_method, 'databases', headers=headers, params=params) + connection.request(request_method, 'databases', headers=headers, params=params) expected = call(request_method, 'https://data.nasdaq.com/api/v3/databases', headers={'x-custom-header': 'header value', 'x-api-token': 'api_token', @@ -81,3 +81,15 @@ def test_build_request(self, request_method, mock): 'request-source-version': VERSION}, params={'per_page': 10, 'page': 2}) self.assertEqual(mock.call_args, expected) + + def test_session_reuse(self): + session1 = connection.get_session() + session2 = connection.get_session() + areSessionsSame = session1 is session2 + + adapter1 = connection.get_session().get_adapter(ApiConfig.api_protocol) + adapter2 = connection.get_session().get_adapter(ApiConfig.api_protocol) + areAdaptersSame = adapter1 is adapter2 + + self.assertEqual(areAdaptersSame, True) + self.assertEqual(areSessionsSame, True) diff --git a/test/test_data.py b/test/test_data.py index 7852dbe..53817a1 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -77,7 +77,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection(self, mock): Data.all(params={'database_code': 'NSE', 'dataset_code': 'OIL'}) expected = call('get', 'datasets/NSE/OIL/data', params={}) diff --git a/test/test_database.py b/test/test_database.py index bbae558..7b38b98 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -7,7 +7,7 @@ from six.moves.urllib.parse import parse_qs, urlparse from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import (InternalServerError, DataLinkError) from nasdaqdatalink.model.database import Database from test.factories.database import DatabaseFactory @@ -34,7 +34,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_database_calls_connection(self, mock): database = Database('NSE') database.data_fields() @@ -80,7 +80,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_databases_calls_connection(self, mock): Database.all() expected = call('get', 'databases', params={}) @@ -148,7 +148,7 @@ def test_get_bulk_download_url_without_download_type(self): def test_bulk_download_to_fileaccepts_download_type(self): m = mock_open() - with patch.object(Connection, 'request') as mock_method: + with patch.object(connection, 'request') as mock_method: mock_method.return_value.url = 'https://www.blah.com/download/db.zip' with patch('nasdaqdatalink.model.database.open', m, create=True): self.database.bulk_download_to_file( diff --git a/test/test_dataset.py b/test/test_dataset.py index c44ea65..aed9b8a 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -30,7 +30,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_dataset_calls_connection(self, mock): d = Dataset('NSE/OIL') d.data_fields() @@ -84,7 +84,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datasets_calls_connection(self, mock): Dataset.all() expected = call('get', 'datasets', params={}) diff --git a/test/test_datatable.py b/test/test_datatable.py index ab80194..ff5525b 100644 --- a/test/test_datatable.py +++ b/test/test_datatable.py @@ -37,26 +37,26 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_metadata_calls_connection(self, mock): Datatable('ZACKS/FC').data_fields() expected = call('get', 'datatables/ZACKS/FC/metadata', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_data_calls_connection_with_no_params_for_get_request(self, mock): Datatable('ZACKS/FC').data() expected = call('get', 'datatables/ZACKS/FC', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_data_calls_connection_with_no_params_for_post_request(self, mock): RequestType.USE_GET_REQUEST = False Datatable('ZACKS/FC').data() expected = call('post', 'datatables/ZACKS/FC', json={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_calls_connection_with_params_for_get_request(self, mock): params = {'ticker': ['AAPL', 'MSFT'], 'per_end_date': {'gte': '2015-01-01'}, @@ -76,7 +76,7 @@ def test_datatable_calls_connection_with_params_for_get_request(self, mock): expected = call('get', 'datatables/ZACKS/FC', params=expected_params) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_calls_connection_with_params_for_post_request(self, mock): RequestType.USE_GET_REQUEST = False params = {'ticker': ['AAPL', 'MSFT'], diff --git a/test/test_datatable_data.py b/test/test_datatable_data.py index 7ba53f3..b837fc9 100644 --- a/test/test_datatable_data.py +++ b/test/test_datatable_data.py @@ -83,7 +83,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection_get(self, mock): datatable = Datatable('ZACKS/FC') Data.page(datatable, params={'ticker': ['AAPL', 'MSFT'], @@ -95,7 +95,7 @@ def test_data_calls_connection_get(self, mock): 'qopts.columns[]': ['ticker', 'per_end_date']}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection_post(self, mock): RequestType.USE_GET_REQUEST = False datatable = Datatable('ZACKS/FC') diff --git a/test/test_get.py b/test/test_get.py index 950c5c5..66f2ba6 100644 --- a/test/test_get.py +++ b/test/test_get.py @@ -8,7 +8,7 @@ from nasdaqdatalink.model.merged_dataset import MergedDataset from nasdaqdatalink.get import get from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection class GetSingleDatasetTest(unittest.TestCase): @@ -36,8 +36,8 @@ def test_returns_numpys_when_requested(self): self.assertIsInstance(result, numpy.core.records.recarray) def test_setting_api_key_config(self): - mock_connection = Mock(wraps=Connection) - with patch('nasdaqdatalink.connection.Connection.execute_request', + mock_connection = Mock(wraps=connection) + with patch('nasdaqdatalink.connection.execute_request', new=mock_connection.execute_request) as mock: ApiConfig.api_key = 'api_key_configured' get('NSE/OIL') diff --git a/test/test_get_point_in_time_data.py b/test/test_get_point_in_time_data.py index 8fa57f7..ef0b236 100644 --- a/test/test_get_point_in_time_data.py +++ b/test/test_get_point_in_time_data.py @@ -27,7 +27,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_point_in_time_returns_data_frame_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_point_in_time( @@ -36,7 +36,7 @@ def test_get_point_in_time_returns_data_frame_object(self, mock): self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='asofdate', date='2020-01-01') @@ -44,7 +44,7 @@ def test_asofdate_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -54,7 +54,7 @@ def test_asofdate_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_without_date(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='asofdate') @@ -62,7 +62,7 @@ def test_asofdate_call_without_date(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -75,7 +75,7 @@ def test_from_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -90,7 +90,7 @@ def test_from_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -103,7 +103,7 @@ def test_between_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -118,7 +118,7 @@ def test_between_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_interval_connection(self, mock): self.assertRaises(InvalidRequestError, lambda: nasdaqdatalink.get_point_in_time('ZACKS/FC')) self.assertRaises( @@ -126,7 +126,7 @@ def test_invalid_interval_connection(self, mock): lambda: nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='nasdaqdatalink') ) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_from_connection(self, mock): self.assertRaises( InvalidRequestError, @@ -145,7 +145,7 @@ def test_invalid_from_connection(self, mock): ) ) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_between_connection(self, mock): self.assertRaises( InvalidRequestError, diff --git a/test/test_get_table.py b/test/test_get_table.py index 8100b66..7f49f84 100644 --- a/test/test_get_table.py +++ b/test/test_get_table.py @@ -37,21 +37,21 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_returns_datatable_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_table('ZACKS/FC', params={}) self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_with_code_returns_datatable_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_table('AR/MWCF', code="ICEP_WAC_Z2017_S") self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_no_params_for_get_request(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_table('ZACKS/FC') @@ -59,7 +59,7 @@ def test_get_table_calls_connection_with_no_params_for_get_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_no_params_for_post_request(self, mock): with self.assertWarns(UserWarning): RequestType.USE_GET_REQUEST = False @@ -69,7 +69,7 @@ def test_get_table_calls_connection_with_no_params_for_post_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_params_for_get_request(self, mock): with self.assertWarns(UserWarning): params = { @@ -93,7 +93,7 @@ def test_get_table_calls_connection_with_params_for_get_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_params_for_post_request(self, mock): with self.assertWarns(UserWarning): RequestType.USE_GET_REQUEST = False diff --git a/test/test_point_in_time.py b/test/test_point_in_time.py index 07918e1..f73e33b 100644 --- a/test/test_point_in_time.py +++ b/test/test_point_in_time.py @@ -26,7 +26,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection(self, mock): PointInTime( 'ZACKS/FC', @@ -38,7 +38,7 @@ def test_asofdate_call_connection(self, mock): expected = call('get', 'pit/ZACKS/FC/asofdate/2020-01-01', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection(self, mock): PointInTime( 'ZACKS/FC', @@ -51,7 +51,7 @@ def test_from_call_connection(self, mock): expected = call('get', 'pit/ZACKS/FC/from/2020-01-01/to/2020-01-02', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection(self, mock): PointInTime( 'ZACKS/FC', diff --git a/test/test_retries.py b/test/test_retries.py index 3028095..69c7653 100644 --- a/test/test_retries.py +++ b/test/test_retries.py @@ -1,7 +1,7 @@ import unittest import json -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.api_config import ApiConfig from test.factories.datatable import DatatableFactory from test.helpers.httpretty_extension import httpretty @@ -28,6 +28,8 @@ def tearDown(self): class TestRetries(ModifyRetrySettingsTestCase): def setUp(self): + # reset session to None before every test + connection.session = None ApiConfig.use_retries = True super(TestRetries, self).setUp() @@ -47,13 +49,13 @@ def setUpClass(cls): def test_modifying_use_retries(self): ApiConfig.use_retries = False - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.total, 0) def test_modifying_number_of_retries(self): ApiConfig.number_of_retries = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.total, ApiConfig.number_of_retries) self.assertEqual(retries.connect, ApiConfig.number_of_retries) @@ -62,19 +64,19 @@ def test_modifying_number_of_retries(self): def test_modifying_retry_backoff_factor(self): ApiConfig.retry_backoff_factor = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.backoff_factor, ApiConfig.retry_backoff_factor) def test_modifying_retry_status_codes(self): ApiConfig.retry_status_codes = [1, 2, 3] - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.status_forcelist, ApiConfig.retry_status_codes) def test_modifying_max_wait_between_retries(self): ApiConfig.max_wait_between_retries = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.BACKOFF_MAX, ApiConfig.max_wait_between_retries) @httpretty.enabled @@ -87,7 +89,7 @@ def test_correct_response_returned_if_retries_succeed(self): "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - response = Connection.request('get', 'databases') + response = connection.request('get', 'databases') self.assertEqual(response.json(), self.datatable) self.assertEqual(response.status_code, self.success_response.status) @@ -100,7 +102,7 @@ def test_correct_response_exception_raised_if_retries_fail(self): "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - self.assertRaises(InternalServerError, Connection.request, 'get', 'databases') + self.assertRaises(InternalServerError, connection.request, 'get', 'databases') @httpretty.enabled def test_correct_response_exception_raised_for_errors_not_in_retry_status_codes(self): @@ -110,4 +112,4 @@ def test_correct_response_exception_raised_for_errors_not_in_retry_status_codes( "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - self.assertRaises(InternalServerError, Connection.request, 'get', 'databases') + self.assertRaises(InternalServerError, connection.request, 'get', 'databases') pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy