Skip to content

Commit

Permalink
Record usage of .save (#1146)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardm-stripe authored Dec 13, 2023
1 parent bf78301 commit 4464b6a
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 40 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,11 @@ stripe.set_app_info("MyAwesomePlugin", version="1.2.34", url="https://myawesomep
This information is passed along when the library makes calls to the Stripe
API.
### Request latency telemetry
### Telemetry
By default, the library sends request latency telemetry to Stripe. These
numbers help Stripe improve the overall latency of its API for all users.
By default, the library sends telemetry to Stripe regarding request latency and feature usage. These
numbers help Stripe improve the overall latency of its API for all users, and
improve popular features.
You can disable this behavior if you prefer:
Expand Down
25 changes: 21 additions & 4 deletions stripe/_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import (
Any,
Dict,
List,
Mapping,
Optional,
Tuple,
Expand Down Expand Up @@ -99,9 +100,16 @@ def request(
url: str,
params: Optional[Mapping[str, Any]] = None,
headers: Optional[Mapping[str, str]] = None,
*,
_usage: Optional[List[str]] = None,
) -> Tuple[StripeResponse, str]:
rbody, rcode, rheaders, my_api_key = self.request_raw(
method.lower(), url, params, headers, is_streaming=False
method.lower(),
url,
params,
headers,
is_streaming=False,
_usage=_usage,
)
resp = self.interpret_response(rbody, rcode, rheaders)
return resp, my_api_key
Expand All @@ -112,9 +120,16 @@ def request_stream(
url: str,
params: Optional[Mapping[str, Any]] = None,
headers: Optional[Mapping[str, str]] = None,
*,
_usage: Optional[List[str]] = None,
) -> Tuple[StripeStreamResponse, str]:
stream, rcode, rheaders, my_api_key = self.request_raw(
method.lower(), url, params, headers, is_streaming=True
method.lower(),
url,
params,
headers,
is_streaming=True,
_usage=_usage,
)
resp = self.interpret_streaming_response(
# TODO: should be able to remove this cast once self._client.request_stream_with_retries
Expand Down Expand Up @@ -282,6 +297,8 @@ def request_raw(
params: Optional[Mapping[str, Any]] = None,
supplied_headers: Optional[Mapping[str, str]] = None,
is_streaming: bool = False,
*,
_usage: Optional[List[str]] = None,
) -> Tuple[object, int, Mapping[str, str], str]:
"""
Mechanism for issuing an API call
Expand Down Expand Up @@ -359,11 +376,11 @@ def request_raw(
rcode,
rheaders,
) = self._client.request_stream_with_retries(
method, abs_url, headers, post_data
method, abs_url, headers, post_data, _usage=_usage
)
else:
rcontent, rcode, rheaders = self._client.request_with_retries(
method, abs_url, headers, post_data
method, abs_url, headers, post_data, _usage=_usage
)

_util.log_info(
Expand Down
3 changes: 3 additions & 0 deletions stripe/_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ClassVar,
Dict,
Generic,
List,
Optional,
TypeVar,
cast,
Expand Down Expand Up @@ -110,6 +111,7 @@ def _request_and_refresh(
stripe_account: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
params: Optional[Mapping[str, Any]] = None,
_usage: Optional[List[str]] = None,
) -> Self:
obj = StripeObject._request(
self,
Expand All @@ -121,6 +123,7 @@ def _request_and_refresh(
stripe_account,
headers,
params,
_usage=_usage,
)

self.refresh_from(obj)
Expand Down
32 changes: 23 additions & 9 deletions stripe/_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from stripe._request_metrics import RequestMetrics
from stripe._error import APIConnectionError

from typing import Any, Dict, Optional, Tuple, ClassVar, Union, cast
from typing import Any, Dict, List, Optional, Tuple, ClassVar, Union, cast
from typing_extensions import NoReturn, TypedDict

# - Requests is the preferred HTTP library
Expand Down Expand Up @@ -137,21 +137,33 @@ def __init__(

# TODO: more specific types here would be helpful
def request_with_retries(
self, method, url, headers, post_data=None
self,
method,
url,
headers,
post_data=None,
*,
_usage: Optional[List[str]] = None,
) -> Tuple[Any, int, Any]:
return self._request_with_retries_internal(
method, url, headers, post_data, is_streaming=False
method, url, headers, post_data, is_streaming=False, _usage=_usage
)

def request_stream_with_retries(
self, method, url, headers, post_data=None
self,
method,
url,
headers,
post_data=None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
return self._request_with_retries_internal(
method, url, headers, post_data, is_streaming=True
method, url, headers, post_data, is_streaming=True, _usage=_usage
)

def _request_with_retries_internal(
self, method, url, headers, post_data, is_streaming
self, method, url, headers, post_data, is_streaming, *, _usage=None
):
self._add_telemetry_header(headers)

Expand Down Expand Up @@ -190,7 +202,9 @@ def _request_with_retries_internal(
time.sleep(sleep_time)
else:
if response is not None:
self._record_request_metrics(response, request_start)
self._record_request_metrics(
response, request_start, _usage
)

return response
else:
Expand Down Expand Up @@ -297,13 +311,13 @@ def _add_telemetry_header(self, headers):
}
headers["X-Stripe-Client-Telemetry"] = json.dumps(telemetry)

def _record_request_metrics(self, response, request_start):
def _record_request_metrics(self, response, request_start, usage):
_, _, rheaders = response
if "Request-Id" in rheaders and stripe.enable_telemetry:
request_id = rheaders["Request-Id"]
request_duration_ms = _now_ms() - request_start
self._thread_local.last_request_metrics = RequestMetrics(
request_id, request_duration_ms
request_id, request_duration_ms, usage=usage
)

def close(self):
Expand Down
17 changes: 15 additions & 2 deletions stripe/_request_metrics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from typing import List, Optional


class RequestMetrics(object):
def __init__(self, request_id, request_duration_ms):
def __init__(
self,
request_id,
request_duration_ms,
usage: Optional[List[str]] = [],
):
self.request_id = request_id
self.request_duration_ms = request_duration_ms
self.usage = usage

def payload(self):
return {
ret = {
"request_id": self.request_id,
"request_duration_ms": self.request_duration_ms,
}

if self.usage is not None and len(self.usage) > 0:
ret["usage"] = self.usage
return ret
5 changes: 4 additions & 1 deletion stripe/_stripe_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def _request(
stripe_account: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
params: Optional[Mapping[str, Any]] = None,
_usage: Optional[List[str]] = None,
) -> "StripeObject":
params = None if params is None else dict(params)
api_key = _util.read_special_variable(params, "api_key", api_key)
Expand Down Expand Up @@ -377,7 +378,9 @@ def _request(
headers = {} if headers is None else headers.copy()
headers.update(_util.populate_headers(idempotency_key))

response, api_key = requestor.request(method_, url_, params, headers)
response, api_key = requestor.request(
method_, url_, params, headers, _usage=_usage
)

return _util.convert_to_stripe_object(
response, api_key, stripe_version, stripe_account, params
Expand Down
1 change: 1 addition & 0 deletions stripe/_updateable_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def save(self, idempotency_key=None):
self.instance_url(),
idempotency_key=idempotency_key,
params=updated_params,
_usage=["save"],
)
else:
_util.logger.debug("Trying to save already saved object %r", self)
Expand Down
34 changes: 22 additions & 12 deletions tests/request_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,26 @@ def assert_api_version(self, expected_api_version):
)
raise AssertionError(msg)

def assert_requested(self, method, url, params=None, headers=None):
def assert_requested(
self, method, url, params=None, headers=None, _usage=None
):
self.assert_requested_internal(
self.request_patcher, method, url, params, headers
self.request_patcher, method, url, params, headers, _usage
)

def assert_requested_stream(self, method, url, params=None, headers=None):
def assert_requested_stream(
self, method, url, params=None, headers=None, _usage=None
):
self.assert_requested_internal(
self.request_stream_patcher, method, url, params, headers
self.request_stream_patcher, method, url, params, headers, _usage
)

def assert_requested_internal(self, patcher, method, url, params, headers):
def assert_requested_internal(
self, patcher, method, url, params, headers, usage
):
params = params or self._mocker.ANY
headers = headers or self._mocker.ANY
usage = usage or self._mocker.ANY
called = False
exception = None

Expand All @@ -134,14 +141,17 @@ def assert_requested_internal(self, patcher, method, url, params, headers):
(self._mocker.ANY, method, url, params, headers),
]

possible_called_kwargs = [{}, {"_usage": usage}]

for args in possible_called_args:
try:
patcher.assert_called_with(*args)
except AssertionError as e:
exception = e
else:
called = True
break
for kwargs in possible_called_kwargs:
try:
patcher.assert_called_with(*args, **kwargs)
except AssertionError as e:
exception = e
else:
called = True
break

if not called:
raise exception
Expand Down
5 changes: 3 additions & 2 deletions tests/test_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,11 @@ def check_call(

if is_streaming:
http_client.request_stream_with_retries.assert_called_with(
method, abs_url, headers, post_data
method, abs_url, headers, post_data, _usage=None
)
else:
http_client.request_with_retries.assert_called_with(
method, abs_url, headers, post_data
method, abs_url, headers, post_data, _usage=None
)

return check_call
Expand Down Expand Up @@ -797,4 +797,5 @@ def test_default_http_client_called(self, mocker):
"https://api.stripe.com/v1/charges?limit=3",
mocker.ANY,
None,
_usage=None,
)
26 changes: 19 additions & 7 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,6 @@ class MockServerRequestHandler(TestHandler):

def test_passes_client_telemetry_when_enabled(self):
class MockServerRequestHandler(TestHandler):
num_requests = 0

def do_request(self, req_num):
if req_num == 0:
time.sleep(31 / 1000) # 31 ms
Expand All @@ -174,7 +172,7 @@ def do_request(self, req_num):
200,
{
"Content-Type": "application/json; charset=utf-8",
"Request-Id": "req_1",
"Request-Id": "req_%s" % (req_num + 1),
},
None,
]
Expand All @@ -183,11 +181,15 @@ def do_request(self, req_num):
stripe.api_base = "http://localhost:%s" % self.mock_server_port
stripe.enable_telemetry = True

stripe.Balance.retrieve()
stripe.Balance.retrieve()
cus = stripe.Customer("cus_xyz")
cus.description = "hello"
cus.save()

stripe.Customer.retrieve("cus_xyz")
stripe.Customer.retrieve("cus_xyz")

reqs = MockServerRequestHandler.get_requests(3)

reqs = MockServerRequestHandler.get_requests(2)
assert MockServerRequestHandler.num_requests == 2
# req 1
assert not reqs[0].headers.get("x-stripe-client-telemetry")
# req 2
Expand All @@ -202,6 +204,16 @@ def do_request(self, req_num):
# latency shouldn't be outside this range.
assert 30 < duration_ms < 300

usage = telemetry["last_request_metrics"]["usage"]
assert usage == ["save"]

# req 3
telemetry_raw = reqs[2].headers.get("x-stripe-client-telemetry")
assert telemetry_raw is not None
metrics = json.loads(telemetry_raw)["last_request_metrics"]
assert metrics["request_id"] == "req_2"
assert "usage" not in metrics

def test_uses_thread_local_client_telemetry(self):
class MockServerRequestHandler(TestHandler):
local_num_requests = 0
Expand Down

0 comments on commit 4464b6a

Please sign in to comment.