Skip to content

Commit

Permalink
Merge pull request #36 from alexgolec/client-tests
Browse files Browse the repository at this point in the history
Add tests for the HTTP client
  • Loading branch information
alexgolec authored May 5, 2024
2 parents 38c6770 + 538f1f4 commit 7cbb21d
Show file tree
Hide file tree
Showing 3 changed files with 588 additions and 878 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test:
pytest
python -m pytest tests/

fix:
autopep8 --in-place -r -a schwab
Expand Down
34 changes: 19 additions & 15 deletions schwab/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,16 @@ def get_accounts(self, *, fields=None):
##########################################################################
# Orders

def cancel_order(self, order_id, account_hash):
'''Cancel a specific order for a specific account'''
path = '/trader/v1/accounts/{}/orders/{}'.format(account_hash, order_id)
return self._delete_request(path)

def get_order(self, order_id, account_hash):
'''Get a specific order for a specific account by its order ID'''
path = '/trader/v1/accounts/{}/orders/{}'.format(account_hash, order_id)
return self._get_request(path, {})

def cancel_order(self, order_id, account_hash):
'''Cancel a specific order for a specific account'''
path = '/trader/v1/accounts/{}/orders/{}'.format(account_hash, order_id)
return self._delete_request(path)

class Order:
class Status(Enum):
'''Order statuses passed to :meth:`get_orders_for_account` and
Expand Down Expand Up @@ -266,8 +266,7 @@ def get_orders_for_all_linked_accounts(self,
max_results=None,
from_entered_datetime=None,
to_entered_datetime=None,
status=None,
statuses=None):
status=None):
'''Orders for all linked accounts. Optionally specify a single status on
which to filter.
Expand Down Expand Up @@ -425,13 +424,13 @@ def get_user_preferences(self):
'''Preferences for the logged in account, including all linked
accounts.'''
path = '/trader/v1/userPreference'
return self._get_request(path, ())
return self._get_request(path, {})


##########################################################################
# Quotes

class GetQuote:
class Quote:
class Fields(Enum):
QUOTE = 'quote'
FUNDAMENTAL = 'fundamental'
Expand All @@ -447,9 +446,9 @@ def get_quote(self, symbol, *, fields=None):
:param fields: Fields to request. If unset, return all available data.
i.e. all fields. See :class:`GetQuote.Field` for options.
'''
fields = self.convert_enum_iterable(fields, self.GetQuote.Fields)
fields = self.convert_enum_iterable(fields, self.Quote.Fields)
if fields:
params = {'fields': fields}
params = {'fields': ','.join(fields)}
else:
params = {}

Expand All @@ -471,9 +470,9 @@ def get_quotes(self, symbols, *, fields=None, indicative=None):
'symbols': ','.join(symbols)
}

fields = self.convert_enum_iterable(fields, self.GetQuote.Fields)
fields = self.convert_enum_iterable(fields, self.Quote.Fields)
if fields:
params['fields'] = fields
params['fields'] = ','.join(fields)

if indicative is not None:
if type(indicative) is not bool:
Expand Down Expand Up @@ -604,9 +603,9 @@ def get_option_chain(
strike_range, self.Options.StrikeRange)
option_type = self.convert_enum(option_type, self.Options.Type)
exp_month = self.convert_enum(exp_month, self.Options.ExpirationMonth)
entitlement = self.convert_enum(entitlement, self.Options.Entitlement)

params = {
'apikey': self.api_key,
'symbol': symbol,
}

Expand Down Expand Up @@ -640,6 +639,8 @@ def get_option_chain(
params['expMonth'] = exp_month
if option_type is not None:
params['optionType'] = option_type
if entitlement is not None:
params['entitlement'] = entitlement

path = '/marketdata/v1/chains'
return self._get_request(path, params)
Expand Down Expand Up @@ -1010,7 +1011,7 @@ def get_movers(self, index, *, sort_order=None, frequency=None):
if sort_order is not None:
params['sort'] = sort_order
if frequency is not None:
params['frequency'] = frequency
params['frequency'] = str(frequency)

return self._get_request(path, params)

Expand Down Expand Up @@ -1114,5 +1115,8 @@ def get_instrument_by_cusip(self, cusip):
:param cusip: String representing CUSIP of instrument for which to fetch
data. Note leading zeroes must be preserved.
'''
if not isinstance(cusip, str):
raise ValueError('cusip must be passed as str')

return self._get_request(
'/marketdata/v1/instruments/{}'.format(cusip), {})
Loading

0 comments on commit 7cbb21d

Please sign in to comment.