Skip to content

Commit

Permalink
Make total_records and total_bytes optional, support None
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoMi committed Jul 26, 2024
1 parent 14dca9f commit 2c8bfa3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 28 deletions.
10 changes: 5 additions & 5 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ cdef class FlightInfo(_Weakrefable):
return obj

def __init__(self, Schema schema, FlightDescriptor descriptor, endpoints,
total_records, total_bytes, ordered=False, app_metadata=""):
total_records=None, total_bytes=None, ordered=False, app_metadata=""):
"""Create a FlightInfo object from a schema, descriptor, and endpoints.
Parameters
Expand All @@ -855,9 +855,9 @@ cdef class FlightInfo(_Weakrefable):
the descriptor for this flight.
endpoints : list of FlightEndpoint
a list of endpoints where this flight is available.
total_records : int
total_records : int optional, default None
the total records in this flight, or -1 if unknown.
total_bytes : int
total_bytes : int optional, default None
the total bytes in this flight, or -1 if unknown.
ordered : boolean optional, default False
Whether endpoints are in the same order as the data.
Expand All @@ -878,8 +878,8 @@ cdef class FlightInfo(_Weakrefable):
check_flight_status(CreateFlightInfo(c_schema,
descriptor.descriptor,
c_endpoints,
total_records,
total_bytes,
total_records if total_records is not None else -1,
total_bytes if total_bytes is not None else -1,
ordered,
tobytes(app_metadata), &self.info))

Expand Down
36 changes: 16 additions & 20 deletions python/pyarrow/tests/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def list_flights(self, context, criteria):
yield flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path('/foo'),
[],
-1, -1
[]
)

def do_get(self, context, ticket):
Expand Down Expand Up @@ -938,56 +937,54 @@ def test_eq():
lambda: (
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], -1, -1),
flight.FlightDescriptor.for_path(), []),
flight.FlightInfo(
pa.schema([("ints", pa.int64())]),
flight.FlightDescriptor.for_path(), [], -1, -1)),
flight.FlightDescriptor.for_path(), [])),
lambda: (
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], -1, -1),
flight.FlightDescriptor.for_path(), []),
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_command(b"foo"), [], -1, -1)),
flight.FlightDescriptor.for_command(b"foo"), [])),
lambda: (
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(),
[flight.FlightEndpoint(b"foo", [])],
-1, -1),
[flight.FlightEndpoint(b"foo", [])]),
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(),
[flight.FlightEndpoint(b"bar", [])],
-1, -1)),
[flight.FlightEndpoint(b"bar", [])])),
lambda: (
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], -1, -1),
flight.FlightDescriptor.for_path(), [], total_records=-1),
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], 1, -1)),
flight.FlightDescriptor.for_path(), [], total_records=1)),
lambda: (
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], -1, -1),
flight.FlightDescriptor.for_path(), [], total_bytes=-1),
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], -1, 42)),
flight.FlightDescriptor.for_path(), [], total_bytes=42)),
lambda: (
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], -1, -1, False),
flight.FlightDescriptor.for_path(), [], ordered=False),
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], -1, -1, True)),
flight.FlightDescriptor.for_path(), [], ordered=True)),
lambda: (
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], -1, -1, app_metadata=b""),
flight.FlightDescriptor.for_path(), [], app_metadata=b""),
flight.FlightInfo(
pa.schema([]),
flight.FlightDescriptor.for_path(), [], -1, -1, app_metadata=b"meta")),
flight.FlightDescriptor.for_path(), [], app_metadata=b"meta")),
lambda: (flight.Location("grpc+tcp://localhost:1234"),
flight.Location("grpc+tls://localhost:1234")),
lambda: (flight.Result(b"foo"), flight.Result(b"bar")),
Expand Down Expand Up @@ -2411,8 +2408,7 @@ def get_flight_info(self, context, descriptor):
return flight.FlightInfo(
pa.schema([]),
descriptor,
[],
-1, -1
[]
)

class HeadersTrailersMiddlewareFactory(ClientMiddlewareFactory):
Expand Down
4 changes: 1 addition & 3 deletions python/pyarrow/tests/test_flight_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ class ExampleServer(flight.FlightServerBase):
simple_info = flight.FlightInfo(
pyarrow.schema([("a", "int32")]),
flight.FlightDescriptor.for_command(b"simple"),
[],
-1,
-1
[]
)

def get_flight_info(self, context, descriptor):
Expand Down

0 comments on commit 2c8bfa3

Please sign in to comment.