From e661698eb6db9b07416655ea4b6df6583b9fbfec Mon Sep 17 00:00:00 2001 From: Ren Ren Date: Wed, 30 Mar 2022 15:49:53 -0400 Subject: [PATCH 1/8] change connection from class to module --- nasdaqdatalink/connection.py | 192 ++++++++++++++-------------- nasdaqdatalink/model/database.py | 2 +- nasdaqdatalink/model/datatable.py | 2 +- nasdaqdatalink/operations/get.py | 2 +- nasdaqdatalink/operations/list.py | 2 +- test/test_connection.py | 4 +- test/test_data.py | 2 +- test/test_database.py | 6 +- test/test_dataset.py | 4 +- test/test_datatable.py | 10 +- test/test_datatable_data.py | 4 +- test/test_get.py | 4 +- test/test_get_point_in_time_data.py | 22 ++-- test/test_get_table.py | 12 +- test/test_point_in_time.py | 6 +- test/test_retries.py | 2 +- 16 files changed, 135 insertions(+), 141 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 2338a5f..73a6241 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -13,103 +13,97 @@ AuthenticationError, ForbiddenError, InvalidRequestError, NotFoundError, ServiceUnavailableError) - -class Connection: - @classmethod - def request(cls, http_verb, url, **options): - if 'headers' in options: - headers = options['headers'] +def request(http_verb, url, **options): + if 'headers' in options: + headers = options['headers'] + else: + headers = {} + + accept_value = 'application/json' + if ApiConfig.api_version: + accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version + + headers = Util.merge_to_dicts({'accept': accept_value, + 'request-source': 'python', + 'request-source-version': VERSION}, headers) + if ApiConfig.api_key: + headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) + + options['headers'] = headers + + abs_url = '%s/%s' % (ApiConfig.api_base, url) + + return execute_request(http_verb, abs_url, **options) + +def execute_request(http_verb, url, **options): + session = get_session(url) + + try: + response = session.request(method=http_verb, + url=url, + verify=ApiConfig.verify_ssl, + **options) + if response.status_code < 200 or response.status_code >= 300: + handle_api_error(response) else: - headers = {} - - accept_value = 'application/json' - if ApiConfig.api_version: - accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version - - headers = Util.merge_to_dicts({'accept': accept_value, - 'request-source': 'python', - 'request-source-version': VERSION}, headers) - if ApiConfig.api_key: - headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) - - options['headers'] = headers - - abs_url = '%s/%s' % (ApiConfig.api_base, url) - - return cls.execute_request(http_verb, abs_url, **options) - - @classmethod - def execute_request(cls, http_verb, url, **options): - session = cls.get_session() - - try: - response = session.request(method=http_verb, - url=url, - verify=ApiConfig.verify_ssl, - **options) - if response.status_code < 200 or response.status_code >= 300: - cls.handle_api_error(response) - else: - return response - except requests.exceptions.RequestException as e: - if e.response: - cls.handle_api_error(e.response) - raise e - - @classmethod - def get_session(cls): - session = requests.Session() - adapter = HTTPAdapter(max_retries=cls.get_retries()) - session.mount(ApiConfig.api_protocol, adapter) - - return session - - @classmethod - def get_retries(cls): - if not ApiConfig.use_retries: - return Retry(total=0) - - Retry.BACKOFF_MAX = ApiConfig.max_wait_between_retries - retries = Retry(total=ApiConfig.number_of_retries, - connect=ApiConfig.number_of_retries, - read=ApiConfig.number_of_retries, - status_forcelist=ApiConfig.retry_status_codes, - backoff_factor=ApiConfig.retry_backoff_factor, - raise_on_status=False) - - return retries - - @classmethod - def parse(cls, response): - try: - return response.json() - except ValueError: - raise DataLinkError(http_status=response.status_code, http_body=response.text) - - @classmethod - def handle_api_error(cls, resp): - error_body = cls.parse(resp) - - # if our app does not form a proper data_link_error response - # throw generic error - if 'error' not in error_body: - raise DataLinkError(http_status=resp.status_code, http_body=resp.text) - - code = error_body['error']['code'] - message = error_body['error']['message'] - prog = re.compile('^QE([a-zA-Z])x') - if prog.match(code): - code_letter = prog.match(code).group(1) - - d_klass = { - 'L': LimitExceededError, - 'M': InternalServerError, - 'A': AuthenticationError, - 'P': ForbiddenError, - 'S': InvalidRequestError, - 'C': NotFoundError, - 'X': ServiceUnavailableError - } - klass = d_klass.get(code_letter, DataLinkError) - - raise klass(message, resp.status_code, resp.text, resp.headers, code) + return response + except requests.exceptions.RequestException as e: + if e.response: + handle_api_error(e.response) + raise e + +def get_retries(): + if not ApiConfig.use_retries: + return Retry(total=0) + + Retry.BACKOFF_MAX = ApiConfig.max_wait_between_retries + retries = Retry(total=ApiConfig.number_of_retries, + connect=ApiConfig.number_of_retries, + read=ApiConfig.number_of_retries, + status_forcelist=ApiConfig.retry_status_codes, + backoff_factor=ApiConfig.retry_backoff_factor, + raise_on_status=False) + + return retries + +session = requests.Session() + +def get_session(url = ApiConfig.api_protocol): + adapter = HTTPAdapter(max_retries=get_retries()) + session.mount(url, adapter) + return session + +def parse(response): + try: + return response.json() + except ValueError: + raise DataLinkError(http_status=response.status_code, http_body=response.text) + + + +def handle_api_error(resp): + error_body = parse(resp) + + # if our app does not form a proper data_link_error response + # throw generic error + if 'error' not in error_body: + raise DataLinkError(http_status=resp.status_code, http_body=resp.text) + + code = error_body['error']['code'] + message = error_body['error']['message'] + prog = re.compile('^QE([a-zA-Z])x') + if prog.match(code): + code_letter = prog.match(code).group(1) + + d_klass = { + 'L': LimitExceededError, + 'M': InternalServerError, + 'A': AuthenticationError, + 'P': ForbiddenError, + 'S': InvalidRequestError, + 'C': NotFoundError, + 'X': ServiceUnavailableError + } + klass = d_klass.get(code_letter, DataLinkError) + + raise klass(message, resp.status_code, resp.text, resp.headers, code) \ No newline at end of file diff --git a/nasdaqdatalink/model/database.py b/nasdaqdatalink/model/database.py index 870dedc..5cde79a 100644 --- a/nasdaqdatalink/model/database.py +++ b/nasdaqdatalink/model/database.py @@ -4,7 +4,7 @@ import nasdaqdatalink.model.dataset from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation diff --git a/nasdaqdatalink/model/datatable.py b/nasdaqdatalink/model/datatable.py index 2edadb8..b253764 100644 --- a/nasdaqdatalink/model/datatable.py +++ b/nasdaqdatalink/model/datatable.py @@ -3,7 +3,7 @@ from six.moves.urllib.request import urlopen -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation diff --git a/nasdaqdatalink/operations/get.py b/nasdaqdatalink/operations/get.py index 8f93b95..3d70a79 100644 --- a/nasdaqdatalink/operations/get.py +++ b/nasdaqdatalink/operations/get.py @@ -1,7 +1,7 @@ from inflection import singularize from .operation import Operation -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.util import Util diff --git a/nasdaqdatalink/operations/list.py b/nasdaqdatalink/operations/list.py index 6aa020a..fb2f5cd 100644 --- a/nasdaqdatalink/operations/list.py +++ b/nasdaqdatalink/operations/list.py @@ -1,5 +1,5 @@ from .operation import Operation -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.util import Util from nasdaqdatalink.model.paginated_list import PaginatedList from nasdaqdatalink.utils.request_type_util import RequestType diff --git a/test/test_connection.py b/test/test_connection.py index 96d8380..3ee62eb 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -1,4 +1,4 @@ -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.api_config import ApiConfig from nasdaqdatalink.errors.data_link_error import ( DataLinkError, LimitExceededError, InternalServerError, @@ -65,7 +65,7 @@ def test_non_data_link_error(self, request_method): DataLinkError, lambda: Connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) - @patch('nasdaqdatalink.connection.Connection.execute_request') + @patch('nasdaqdatalink.connection.execute_request') def test_build_request(self, request_method, mock): ApiConfig.api_key = 'api_token' ApiConfig.api_version = '2015-04-09' diff --git a/test/test_data.py b/test/test_data.py index 7852dbe..53817a1 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -77,7 +77,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection(self, mock): Data.all(params={'database_code': 'NSE', 'dataset_code': 'OIL'}) expected = call('get', 'datasets/NSE/OIL/data', params={}) diff --git a/test/test_database.py b/test/test_database.py index bbae558..0b11cec 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -7,7 +7,7 @@ from six.moves.urllib.parse import parse_qs, urlparse from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.errors.data_link_error import (InternalServerError, DataLinkError) from nasdaqdatalink.model.database import Database from test.factories.database import DatabaseFactory @@ -34,7 +34,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_database_calls_connection(self, mock): database = Database('NSE') database.data_fields() @@ -80,7 +80,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_databases_calls_connection(self, mock): Database.all() expected = call('get', 'databases', params={}) diff --git a/test/test_dataset.py b/test/test_dataset.py index c44ea65..aed9b8a 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -30,7 +30,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_dataset_calls_connection(self, mock): d = Dataset('NSE/OIL') d.data_fields() @@ -84,7 +84,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datasets_calls_connection(self, mock): Dataset.all() expected = call('get', 'datasets', params={}) diff --git a/test/test_datatable.py b/test/test_datatable.py index ab80194..ff5525b 100644 --- a/test/test_datatable.py +++ b/test/test_datatable.py @@ -37,26 +37,26 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_metadata_calls_connection(self, mock): Datatable('ZACKS/FC').data_fields() expected = call('get', 'datatables/ZACKS/FC/metadata', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_data_calls_connection_with_no_params_for_get_request(self, mock): Datatable('ZACKS/FC').data() expected = call('get', 'datatables/ZACKS/FC', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_data_calls_connection_with_no_params_for_post_request(self, mock): RequestType.USE_GET_REQUEST = False Datatable('ZACKS/FC').data() expected = call('post', 'datatables/ZACKS/FC', json={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_calls_connection_with_params_for_get_request(self, mock): params = {'ticker': ['AAPL', 'MSFT'], 'per_end_date': {'gte': '2015-01-01'}, @@ -76,7 +76,7 @@ def test_datatable_calls_connection_with_params_for_get_request(self, mock): expected = call('get', 'datatables/ZACKS/FC', params=expected_params) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_calls_connection_with_params_for_post_request(self, mock): RequestType.USE_GET_REQUEST = False params = {'ticker': ['AAPL', 'MSFT'], diff --git a/test/test_datatable_data.py b/test/test_datatable_data.py index 7ba53f3..b837fc9 100644 --- a/test/test_datatable_data.py +++ b/test/test_datatable_data.py @@ -83,7 +83,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection_get(self, mock): datatable = Datatable('ZACKS/FC') Data.page(datatable, params={'ticker': ['AAPL', 'MSFT'], @@ -95,7 +95,7 @@ def test_data_calls_connection_get(self, mock): 'qopts.columns[]': ['ticker', 'per_end_date']}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection_post(self, mock): RequestType.USE_GET_REQUEST = False datatable = Datatable('ZACKS/FC') diff --git a/test/test_get.py b/test/test_get.py index 950c5c5..8e753c7 100644 --- a/test/test_get.py +++ b/test/test_get.py @@ -8,7 +8,7 @@ from nasdaqdatalink.model.merged_dataset import MergedDataset from nasdaqdatalink.get import get from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection class GetSingleDatasetTest(unittest.TestCase): @@ -37,7 +37,7 @@ def test_returns_numpys_when_requested(self): def test_setting_api_key_config(self): mock_connection = Mock(wraps=Connection) - with patch('nasdaqdatalink.connection.Connection.execute_request', + with patch('nasdaqdatalink.connection.execute_request', new=mock_connection.execute_request) as mock: ApiConfig.api_key = 'api_key_configured' get('NSE/OIL') diff --git a/test/test_get_point_in_time_data.py b/test/test_get_point_in_time_data.py index 8fa57f7..ef0b236 100644 --- a/test/test_get_point_in_time_data.py +++ b/test/test_get_point_in_time_data.py @@ -27,7 +27,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_point_in_time_returns_data_frame_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_point_in_time( @@ -36,7 +36,7 @@ def test_get_point_in_time_returns_data_frame_object(self, mock): self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='asofdate', date='2020-01-01') @@ -44,7 +44,7 @@ def test_asofdate_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -54,7 +54,7 @@ def test_asofdate_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_without_date(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='asofdate') @@ -62,7 +62,7 @@ def test_asofdate_call_without_date(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -75,7 +75,7 @@ def test_from_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -90,7 +90,7 @@ def test_from_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -103,7 +103,7 @@ def test_between_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -118,7 +118,7 @@ def test_between_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_interval_connection(self, mock): self.assertRaises(InvalidRequestError, lambda: nasdaqdatalink.get_point_in_time('ZACKS/FC')) self.assertRaises( @@ -126,7 +126,7 @@ def test_invalid_interval_connection(self, mock): lambda: nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='nasdaqdatalink') ) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_from_connection(self, mock): self.assertRaises( InvalidRequestError, @@ -145,7 +145,7 @@ def test_invalid_from_connection(self, mock): ) ) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_between_connection(self, mock): self.assertRaises( InvalidRequestError, diff --git a/test/test_get_table.py b/test/test_get_table.py index 8100b66..7f49f84 100644 --- a/test/test_get_table.py +++ b/test/test_get_table.py @@ -37,21 +37,21 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_returns_datatable_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_table('ZACKS/FC', params={}) self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_with_code_returns_datatable_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_table('AR/MWCF', code="ICEP_WAC_Z2017_S") self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_no_params_for_get_request(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_table('ZACKS/FC') @@ -59,7 +59,7 @@ def test_get_table_calls_connection_with_no_params_for_get_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_no_params_for_post_request(self, mock): with self.assertWarns(UserWarning): RequestType.USE_GET_REQUEST = False @@ -69,7 +69,7 @@ def test_get_table_calls_connection_with_no_params_for_post_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_params_for_get_request(self, mock): with self.assertWarns(UserWarning): params = { @@ -93,7 +93,7 @@ def test_get_table_calls_connection_with_params_for_get_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_params_for_post_request(self, mock): with self.assertWarns(UserWarning): RequestType.USE_GET_REQUEST = False diff --git a/test/test_point_in_time.py b/test/test_point_in_time.py index 07918e1..f73e33b 100644 --- a/test/test_point_in_time.py +++ b/test/test_point_in_time.py @@ -26,7 +26,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection(self, mock): PointInTime( 'ZACKS/FC', @@ -38,7 +38,7 @@ def test_asofdate_call_connection(self, mock): expected = call('get', 'pit/ZACKS/FC/asofdate/2020-01-01', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection(self, mock): PointInTime( 'ZACKS/FC', @@ -51,7 +51,7 @@ def test_from_call_connection(self, mock): expected = call('get', 'pit/ZACKS/FC/from/2020-01-01/to/2020-01-02', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection(self, mock): PointInTime( 'ZACKS/FC', diff --git a/test/test_retries.py b/test/test_retries.py index 3028095..857580e 100644 --- a/test/test_retries.py +++ b/test/test_retries.py @@ -1,7 +1,7 @@ import unittest import json -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.api_config import ApiConfig from test.factories.datatable import DatatableFactory from test.helpers.httpretty_extension import httpretty From 24e8845c514f1989fba863e7fed4b0ba7dd44e57 Mon Sep 17 00:00:00 2001 From: Ren Ren Date: Fri, 1 Apr 2022 13:20:27 -0400 Subject: [PATCH 2/8] reuse adapter --- nasdaqdatalink/connection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 73a6241..72335f8 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -36,7 +36,7 @@ def request(http_verb, url, **options): return execute_request(http_verb, abs_url, **options) def execute_request(http_verb, url, **options): - session = get_session(url) + session = get_session() try: response = session.request(method=http_verb, @@ -67,10 +67,10 @@ def get_retries(): return retries session = requests.Session() +adapter = HTTPAdapter(max_retries=get_retries()) +session.mount(ApiConfig.api_protocol, adapter) -def get_session(url = ApiConfig.api_protocol): - adapter = HTTPAdapter(max_retries=get_retries()) - session.mount(url, adapter) +def get_session(): return session def parse(response): From c258098f3093f8d587c56a72909092f15a78c9ed Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 11:22:27 -0400 Subject: [PATCH 3/8] change initialization flow to allow configurate --- nasdaqdatalink/connection.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 72335f8..94b80f0 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -66,11 +66,14 @@ def get_retries(): return retries -session = requests.Session() -adapter = HTTPAdapter(max_retries=get_retries()) -session.mount(ApiConfig.api_protocol, adapter) +session = None def get_session(): + global session + if session is None: + session = requests.Session() + adapter = HTTPAdapter(max_retries=get_retries()) + session.mount(ApiConfig.api_protocol, adapter) return session def parse(response): From a715192f6d8ff58e7c15948061c33d34dcd55cad Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 15:36:20 -0400 Subject: [PATCH 4/8] change module naming, fix test --- nasdaqdatalink/model/database.py | 4 ++-- nasdaqdatalink/model/datatable.py | 4 ++-- nasdaqdatalink/operations/get.py | 4 ++-- nasdaqdatalink/operations/list.py | 6 +++--- test/test_connection.py | 10 +++++----- test/test_database.py | 4 ++-- test/test_get.py | 4 ++-- test/test_retries.py | 20 +++++++++++--------- 8 files changed, 29 insertions(+), 27 deletions(-) diff --git a/nasdaqdatalink/model/database.py b/nasdaqdatalink/model/database.py index 5cde79a..fbf9e73 100644 --- a/nasdaqdatalink/model/database.py +++ b/nasdaqdatalink/model/database.py @@ -4,7 +4,7 @@ import nasdaqdatalink.model.dataset from nasdaqdatalink.api_config import ApiConfig -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation @@ -43,7 +43,7 @@ def bulk_download_to_file(self, file_or_folder_path, **options): path_url = self._bulk_download_path() options['stream'] = True - r = Connection.request('get', path_url, **options) + r = connection.request('get', path_url, **options) file_path = file_or_folder_path if os.path.isdir(file_or_folder_path): file_path = file_or_folder_path + '/' + os.path.basename(urlparse(r.url).path) diff --git a/nasdaqdatalink/model/datatable.py b/nasdaqdatalink/model/datatable.py index b253764..935590e 100644 --- a/nasdaqdatalink/model/datatable.py +++ b/nasdaqdatalink/model/datatable.py @@ -3,7 +3,7 @@ from six.moves.urllib.request import urlopen -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation @@ -51,7 +51,7 @@ def _request_file_info(self, file_or_folder_path, **options): updated_options = Util.convert_options(request_type=request_type, **options) - r = Connection.request(request_type, url, **updated_options) + r = connection.request(request_type, url, **updated_options) response_data = r.json() diff --git a/nasdaqdatalink/operations/get.py b/nasdaqdatalink/operations/get.py index 3d70a79..efe3a26 100644 --- a/nasdaqdatalink/operations/get.py +++ b/nasdaqdatalink/operations/get.py @@ -1,7 +1,7 @@ from inflection import singularize from .operation import Operation -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.util import Util @@ -21,7 +21,7 @@ def __get_raw_data__(self): path = Util.constructed_path(cls.get_path(), options['params']) - r = Connection.request('get', path, **options) + r = connection.request('get', path, **options) response_data = r.json() Util.convert_to_dates(response_data) self._raw_data = response_data[singularize(cls.lookup_key())] diff --git a/nasdaqdatalink/operations/list.py b/nasdaqdatalink/operations/list.py index fb2f5cd..6e94e78 100644 --- a/nasdaqdatalink/operations/list.py +++ b/nasdaqdatalink/operations/list.py @@ -1,5 +1,5 @@ from .operation import Operation -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.util import Util from nasdaqdatalink.model.paginated_list import PaginatedList from nasdaqdatalink.utils.request_type_util import RequestType @@ -12,7 +12,7 @@ def all(cls, **options): if 'params' not in options: options['params'] = {} path = Util.constructed_path(cls.list_path(), options['params']) - r = Connection.request('get', path, **options) + r = connection.request('get', path, **options) response_data = r.json() Util.convert_to_dates(response_data) resource = cls.create_list_from_response(response_data) @@ -27,7 +27,7 @@ def page(cls, datatable, **options): updated_options = Util.convert_options(request_type=request_type, **options) - r = Connection.request(request_type, path, **updated_options) + r = connection.request(request_type, path, **updated_options) response_data = r.json() Util.convert_to_dates(response_data) diff --git a/test/test_connection.py b/test/test_connection.py index 3ee62eb..7cf1845 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -1,4 +1,4 @@ -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.api_config import ApiConfig from nasdaqdatalink.errors.data_link_error import ( DataLinkError, LimitExceededError, InternalServerError, @@ -42,7 +42,7 @@ def test_nasdaqdatalink_exceptions_no_retries(self, request_method): for expected_error in data_link_errors: self.assertRaises( - expected_error[2], lambda: Connection.request(request_method, 'databases')) + expected_error[2], lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) def test_parse_error(self, request_method): @@ -51,7 +51,7 @@ def test_parse_error(self, request_method): "https://data.nasdaq.com/api/v3/databases", body="not json", status=500) self.assertRaises( - DataLinkError, lambda: Connection.request(request_method, 'databases')) + DataLinkError, lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) def test_non_data_link_error(self, request_method): @@ -62,7 +62,7 @@ def test_non_data_link_error(self, request_method): {'foobar': {'code': 'blah', 'message': 'something went wrong'}}), status=500) self.assertRaises( - DataLinkError, lambda: Connection.request(request_method, 'databases')) + DataLinkError, lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) @patch('nasdaqdatalink.connection.execute_request') @@ -71,7 +71,7 @@ def test_build_request(self, request_method, mock): ApiConfig.api_version = '2015-04-09' params = {'per_page': 10, 'page': 2} headers = {'x-custom-header': 'header value'} - Connection.request(request_method, 'databases', headers=headers, params=params) + connection.request(request_method, 'databases', headers=headers, params=params) expected = call(request_method, 'https://data.nasdaq.com/api/v3/databases', headers={'x-custom-header': 'header value', 'x-api-token': 'api_token', diff --git a/test/test_database.py b/test/test_database.py index 0b11cec..7b38b98 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -7,7 +7,7 @@ from six.moves.urllib.parse import parse_qs, urlparse from nasdaqdatalink.api_config import ApiConfig -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import (InternalServerError, DataLinkError) from nasdaqdatalink.model.database import Database from test.factories.database import DatabaseFactory @@ -148,7 +148,7 @@ def test_get_bulk_download_url_without_download_type(self): def test_bulk_download_to_fileaccepts_download_type(self): m = mock_open() - with patch.object(Connection, 'request') as mock_method: + with patch.object(connection, 'request') as mock_method: mock_method.return_value.url = 'https://www.blah.com/download/db.zip' with patch('nasdaqdatalink.model.database.open', m, create=True): self.database.bulk_download_to_file( diff --git a/test/test_get.py b/test/test_get.py index 8e753c7..66f2ba6 100644 --- a/test/test_get.py +++ b/test/test_get.py @@ -8,7 +8,7 @@ from nasdaqdatalink.model.merged_dataset import MergedDataset from nasdaqdatalink.get import get from nasdaqdatalink.api_config import ApiConfig -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection class GetSingleDatasetTest(unittest.TestCase): @@ -36,7 +36,7 @@ def test_returns_numpys_when_requested(self): self.assertIsInstance(result, numpy.core.records.recarray) def test_setting_api_key_config(self): - mock_connection = Mock(wraps=Connection) + mock_connection = Mock(wraps=connection) with patch('nasdaqdatalink.connection.execute_request', new=mock_connection.execute_request) as mock: ApiConfig.api_key = 'api_key_configured' diff --git a/test/test_retries.py b/test/test_retries.py index 857580e..69c7653 100644 --- a/test/test_retries.py +++ b/test/test_retries.py @@ -1,7 +1,7 @@ import unittest import json -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.api_config import ApiConfig from test.factories.datatable import DatatableFactory from test.helpers.httpretty_extension import httpretty @@ -28,6 +28,8 @@ def tearDown(self): class TestRetries(ModifyRetrySettingsTestCase): def setUp(self): + # reset session to None before every test + connection.session = None ApiConfig.use_retries = True super(TestRetries, self).setUp() @@ -47,13 +49,13 @@ def setUpClass(cls): def test_modifying_use_retries(self): ApiConfig.use_retries = False - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.total, 0) def test_modifying_number_of_retries(self): ApiConfig.number_of_retries = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.total, ApiConfig.number_of_retries) self.assertEqual(retries.connect, ApiConfig.number_of_retries) @@ -62,19 +64,19 @@ def test_modifying_number_of_retries(self): def test_modifying_retry_backoff_factor(self): ApiConfig.retry_backoff_factor = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.backoff_factor, ApiConfig.retry_backoff_factor) def test_modifying_retry_status_codes(self): ApiConfig.retry_status_codes = [1, 2, 3] - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.status_forcelist, ApiConfig.retry_status_codes) def test_modifying_max_wait_between_retries(self): ApiConfig.max_wait_between_retries = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.BACKOFF_MAX, ApiConfig.max_wait_between_retries) @httpretty.enabled @@ -87,7 +89,7 @@ def test_correct_response_returned_if_retries_succeed(self): "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - response = Connection.request('get', 'databases') + response = connection.request('get', 'databases') self.assertEqual(response.json(), self.datatable) self.assertEqual(response.status_code, self.success_response.status) @@ -100,7 +102,7 @@ def test_correct_response_exception_raised_if_retries_fail(self): "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - self.assertRaises(InternalServerError, Connection.request, 'get', 'databases') + self.assertRaises(InternalServerError, connection.request, 'get', 'databases') @httpretty.enabled def test_correct_response_exception_raised_for_errors_not_in_retry_status_codes(self): @@ -110,4 +112,4 @@ def test_correct_response_exception_raised_for_errors_not_in_retry_status_codes( "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - self.assertRaises(InternalServerError, Connection.request, 'get', 'databases') + self.assertRaises(InternalServerError, connection.request, 'get', 'databases') From d9105388e72621ae50498bb884b49b4dc2a4aabe Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 16:34:53 -0400 Subject: [PATCH 5/8] fix lint --- nasdaqdatalink/connection.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 94b80f0..ffd7d4a 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -13,6 +13,10 @@ AuthenticationError, ForbiddenError, InvalidRequestError, NotFoundError, ServiceUnavailableError) +# global session +session = None + + def request(http_verb, url, **options): if 'headers' in options: headers = options['headers'] @@ -23,9 +27,11 @@ def request(http_verb, url, **options): if ApiConfig.api_version: accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version - headers = Util.merge_to_dicts({'accept': accept_value, - 'request-source': 'python', - 'request-source-version': VERSION}, headers) + headers = Util.merge_to_dicts({ + 'accept': accept_value, + 'request-source': 'python', + 'request-source-version': VERSION + }, headers) if ApiConfig.api_key: headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) @@ -35,14 +41,17 @@ def request(http_verb, url, **options): return execute_request(http_verb, abs_url, **options) + def execute_request(http_verb, url, **options): session = get_session() try: - response = session.request(method=http_verb, - url=url, - verify=ApiConfig.verify_ssl, - **options) + response = session.request( + method=http_verb, + url=url, + verify=ApiConfig.verify_ssl, + **options + ) if response.status_code < 200 or response.status_code >= 300: handle_api_error(response) else: @@ -52,6 +61,7 @@ def execute_request(http_verb, url, **options): handle_api_error(e.response) raise e + def get_retries(): if not ApiConfig.use_retries: return Retry(total=0) @@ -66,16 +76,17 @@ def get_retries(): return retries -session = None def get_session(): global session if session is None: + print("initialized") session = requests.Session() adapter = HTTPAdapter(max_retries=get_retries()) session.mount(ApiConfig.api_protocol, adapter) return session + def parse(response): try: return response.json() @@ -83,7 +94,6 @@ def parse(response): raise DataLinkError(http_status=response.status_code, http_body=response.text) - def handle_api_error(resp): error_body = parse(resp) @@ -109,4 +119,4 @@ def handle_api_error(resp): } klass = d_klass.get(code_letter, DataLinkError) - raise klass(message, resp.status_code, resp.text, resp.headers, code) \ No newline at end of file + raise klass(message, resp.status_code, resp.text, resp.headers, code) From c79d4501c6e436a2cfbd84a9669206b20125155d Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 16:38:49 -0400 Subject: [PATCH 6/8] remove debug print --- nasdaqdatalink/connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index ffd7d4a..718db1e 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -80,7 +80,6 @@ def get_retries(): def get_session(): global session if session is None: - print("initialized") session = requests.Session() adapter = HTTPAdapter(max_retries=get_retries()) session.mount(ApiConfig.api_protocol, adapter) From 2b0f25d955a2b998352f060a36be62e50e7cf4bd Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 22:14:52 -0400 Subject: [PATCH 7/8] add test for session reuse --- test/test_connection.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/test_connection.py b/test/test_connection.py index 7cf1845..c5f75e9 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -81,3 +81,15 @@ def test_build_request(self, request_method, mock): 'request-source-version': VERSION}, params={'per_page': 10, 'page': 2}) self.assertEqual(mock.call_args, expected) + + def test_session_reuse(self): + session1 = connection.get_session() + session2 = connection.get_session() + areSessionsSame = session1 is session2 + + adapter1 = connection.get_session().get_adapter(ApiConfig.api_protocol) + adapter2 = connection.get_session().get_adapter(ApiConfig.api_protocol) + areAdaptersSame = adapter1 is adapter2 + + self.assertEqual(areAdaptersSame, True) + self.assertEqual(areSessionsSame, True) From 9bc8dd0729882daddb478e08d7454dcd5f8eafa4 Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 22:17:50 -0400 Subject: [PATCH 8/8] fix lint --- test/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_connection.py b/test/test_connection.py index c5f75e9..384d1e0 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -90,6 +90,6 @@ def test_session_reuse(self): adapter1 = connection.get_session().get_adapter(ApiConfig.api_protocol) adapter2 = connection.get_session().get_adapter(ApiConfig.api_protocol) areAdaptersSame = adapter1 is adapter2 - + self.assertEqual(areAdaptersSame, True) self.assertEqual(areSessionsSame, True) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy