Skip to content

Commit

Permalink
add tests for handlers cleanly exiting (#267)
Browse files Browse the repository at this point in the history
## Why

<!-- Describe what you are trying to accomplish with this pull request
-->

## What changed

<!-- Describe the changes you made in this pull request or pointers for
the reviewer -->

## Versioning

- [ ] Breaking protocol change
- [ ] Breaking ts/js API change

<!-- Kind reminder to add tests and updated documentation if needed -->
  • Loading branch information
jackyzha0 authored Sep 5, 2024
1 parent 4f610d4 commit 1bc1ce9
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 13 deletions.
217 changes: 216 additions & 1 deletion __tests__/cancellation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,222 @@ function makeMockHandler<T extends ValidProcType>(
>(impl);
}

describe.each(testMatrix(['ws', 'naive']))(
describe.each(testMatrix())(
'clean handler cancellation ($transport.name transport, $codec.name codec)',

async ({ transport, codec }) => {
const opts = { codec: codec.codec };

const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups();
let getClientTransport: TestSetupHelpers['getClientTransport'];
let getServerTransport: TestSetupHelpers['getServerTransport'];
beforeEach(async () => {
const setup = await transport.setup({ client: opts, server: opts });
getClientTransport = setup.getClientTransport;
getServerTransport = setup.getServerTransport;

return async () => {
await postTestCleanup();
await setup.cleanup();
};
});

describe('e2e', () => {
test('rpc', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();

const signalReceiver = vi.fn<(sig: AbortSignal) => void>();
const services = {
service: ServiceSchema.define({
rpc: Procedure.rpc({
requestInit: Type.Object({}),
responseData: Type.Object({}),
handler: async ({ ctx }) => {
signalReceiver(ctx.signal);

return Ok({});
},
}),
}),
};

const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

await client.service.rpc.rpc({});

await waitFor(() => {
expect(signalReceiver).toHaveBeenCalledTimes(1);
});

const [sig] = signalReceiver.mock.calls[0];
expect(sig.aborted).toEqual(true);

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('stream', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

const signalReceiver = vi.fn<(sig: AbortSignal) => void>();
const services = {
service: ServiceSchema.define({
stream: Procedure.stream({
requestInit: Type.Object({}),
requestData: Type.Object({}),
responseData: Type.Object({}),
handler: async ({ ctx, resWritable }) => {
signalReceiver(ctx.signal);

resWritable.write(Ok({}));
resWritable.close();

return;
},
}),
}),
};

const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

const { reqWritable, resReadable } = client.service.stream.stream({});

await waitFor(() => {
expect(signalReceiver).toHaveBeenCalledTimes(1);
});

const [sig] = signalReceiver.mock.calls[0];
expect(sig.aborted).toEqual(false);

reqWritable.close();
await waitFor(() => expect(sig.aborted).toEqual(true));

// collect should resolve as the stream has been properly ended
await expect(resReadable.collect()).resolves.toEqual([Ok({})]);

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('upload', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

const signalReceiver = vi.fn<(sig: AbortSignal) => void>();
const services = {
service: ServiceSchema.define({
upload: Procedure.upload({
requestInit: Type.Object({}),
requestData: Type.Object({}),
responseData: Type.Object({}),
handler: async ({ ctx }) => {
signalReceiver(ctx.signal);

return Ok({});
},
}),
}),
};

const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

const { reqWritable, finalize } = client.service.upload.upload({});

await waitFor(() => {
expect(signalReceiver).toHaveBeenCalledTimes(1);
});

const [sig] = signalReceiver.mock.calls[0];
expect(sig.aborted).toEqual(false);

reqWritable.close();
await waitFor(() => expect(sig.aborted).toEqual(true));

expect(await finalize()).toEqual(Ok({}));

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('subscribe', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

const signalReceiver = vi.fn<(sig: AbortSignal) => void>();
const services = {
service: ServiceSchema.define({
subscribe: Procedure.subscription({
requestInit: Type.Object({}),
responseData: Type.Object({}),
handler: async ({ ctx, resWritable }) => {
resWritable.close();
signalReceiver(ctx.signal);

return;
},
}),
}),
};

const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

const { resReadable } = client.service.subscribe.subscribe({});

await waitFor(() => {
expect(signalReceiver).toHaveBeenCalledTimes(1);
});

const [sig] = signalReceiver.mock.calls[0];
expect(sig.aborted).toEqual(true);
await expect(resReadable.collect()).resolves.toEqual([]);

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});
});
},
);

describe.each(testMatrix())(
'client initiated cancellation ($transport.name transport, $codec.name codec)',
async ({ transport, codec }) => {
const opts = { codec: codec.codec };
Expand Down
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@replit/river",
"description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!",
"version": "0.200.4",
"version": "0.200.5",
"type": "module",
"exports": {
".": {
Expand Down
15 changes: 6 additions & 9 deletions router/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,10 @@ class RiverServer<Services extends AnyServiceSchemaMap>
}

// if its not a cancelled stream, validate and create a new stream
const newStream = this.createNewProcStream({
this.createNewProcStream({
...newStreamProps,
...message,
});
this.streams.set(streamId, newStream);
};

const handleSessionStatus = (evt: EventMap['sessionStatus']) => {
Expand Down Expand Up @@ -235,7 +234,7 @@ class RiverServer<Services extends AnyServiceSchemaMap>
this.transport.addEventListener('transportStatus', handleTransportStatus);
}

private createNewProcStream(props: StreamInitProps): ProcStream {
private createNewProcStream(props: StreamInitProps) {
const {
streamId,
initialSession,
Expand Down Expand Up @@ -369,6 +368,7 @@ class RiverServer<Services extends AnyServiceSchemaMap>
});
};

const finishedController = new AbortController();
const procStream: ProcStream = {
from: from,
streamId,
Expand Down Expand Up @@ -422,7 +422,6 @@ class RiverServer<Services extends AnyServiceSchemaMap>
cancelStream(streamId, result);
};

const finishedController = new AbortController();
const cleanup = () => {
finishedController.abort();
this.streams.delete(streamId);
Expand Down Expand Up @@ -531,10 +530,6 @@ class RiverServer<Services extends AnyServiceSchemaMap>
// only consists of an init message and we shouldn't expect follow up data
if (procClosesWithInit) {
closeReadable();
} else if (procedure.type === 'rpc' || procedure.type === 'subscription') {
// Though things can work just fine if they eventually follow up with a stream
// control message with a close bit set, it's an unusual client implementation!
this.log?.warn('sent an init without a stream close', loggingMetadata);
}

const handlerContext: ProcedureHandlerContext<object> = {
Expand Down Expand Up @@ -658,7 +653,9 @@ class RiverServer<Services extends AnyServiceSchemaMap>
break;
}

return procStream;
if (!finishedController.signal.aborted) {
this.streams.set(streamId, procStream);
}
}

private getContext(service: AnyService, serviceName: string) {
Expand Down

0 comments on commit 1bc1ce9

Please sign in to comment.