Skip to content

Commit

Permalink
Support anthropic PDF documents and citations API
Browse files Browse the repository at this point in the history
  • Loading branch information
abrenneke committed Feb 7, 2025
1 parent be12e98 commit 74edaec
Show file tree
Hide file tree
Showing 13 changed files with 635 additions and 53 deletions.
38 changes: 31 additions & 7 deletions packages/app/src/components/NodeOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ import { type FC, type ReactNode, memo, useMemo, useState, type MouseEvent } fro
import { useUnknownNodeComponentDescriptorFor } from '../hooks/useNodeTypes.js';
import { useStableCallback } from '../hooks/useStableCallback.js';
import { copyToClipboard } from '../utils/copyToClipboard.js';
import { type ChartNode, type PortId, type ProcessId, getWarnings, type Outputs } from '@ironclad/rivet-core';
import {
type ChartNode,
type PortId,
type ProcessId,
getWarnings,
type Outputs,
type ChatMessageDataValue,
} from '@ironclad/rivet-core';
import { css } from '@emotion/react';
import CopyIcon from 'majesticons/line/clipboard-line.svg?react';
import ExpandIcon from 'majesticons/line/maximize-line.svg?react';
Expand All @@ -27,6 +34,7 @@ import Toggle from '@atlaskit/toggle';
import { pinnedNodesState } from '../state/graphBuilder';
import { useNodeIO } from '../hooks/useGetNodeIO';
import { Tooltip } from './Tooltip';
import { getGlobalDataRef } from '../utils/globals';

export const NodeOutput: FC<{ node: ChartNode }> = memo(({ node }) => {
const [isModalOpen, setIsModalOpen] = useState(false);
Expand Down Expand Up @@ -217,11 +225,19 @@ const NodeFullscreenOutput: FC<{ node: ChartNode }> = ({ node }) => {
if (outputValue.type === 'string') {
copyToClipboard(outputValue.value);
} else if (outputValue.type === 'chat-message') {
if (Array.isArray(outputValue.value)) {
const singleString = outputValue.value.map((v) => (typeof v === 'string' ? v : '(Image)')).join('\n\n');
const resolved = getGlobalDataRef(outputValue.value.ref);

if (!resolved) {
return;
}

const chatMessage = resolved as ChatMessageDataValue | ChatMessageDataValue[];

if (Array.isArray(chatMessage)) {
const singleString = chatMessage.map((v) => (typeof v === 'string' ? v : '(Image)')).join('\n\n');
copyToClipboard(singleString);
} else {
copyToClipboard(typeof outputValue.value.message === 'string' ? outputValue.value.message : '(Image)');
copyToClipboard(typeof chatMessage.value.message === 'string' ? chatMessage.value.message : '(Image)');
}
} else {
copyToClipboard(JSON.stringify(outputValue, null, 2));
Expand Down Expand Up @@ -408,11 +424,19 @@ const NodeOutputSingleProcess: FC<{
if (outputValue.type === 'string') {
copyToClipboard(outputValue.value);
} else if (outputValue.type === 'chat-message') {
if (Array.isArray(outputValue.value)) {
const singleString = outputValue.value.map((v) => (typeof v === 'string' ? v : '(Image)')).join('\n\n');
const resolved = getGlobalDataRef(outputValue.value.ref);

if (!resolved) {
return;
}

const chatMessage = resolved as ChatMessageDataValue | ChatMessageDataValue;

if (Array.isArray(chatMessage.value)) {
const singleString = chatMessage.value.map((v) => (typeof v === 'string' ? v : '(Image)')).join('\n\n');
copyToClipboard(singleString);
} else {
copyToClipboard(typeof outputValue.value.message === 'string' ? outputValue.value.message : '(Image)');
copyToClipboard(typeof chatMessage.value.message === 'string' ? chatMessage.value.message : '(Image)');
}
} else {
copyToClipboard(JSON.stringify(outputValue, null, 2));
Expand Down
52 changes: 49 additions & 3 deletions packages/app/src/components/RenderDataValue.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import {
type AudioDataValue,
isArrayDataType,
type ScalarOrArrayDataValue,
type DocumentDataValue,
type ChatMessageDataValue,
} from '@ironclad/rivet-core';
import { css } from '@emotion/react';
import { keys } from '../../../core/src/utils/typeSafety';
Expand All @@ -23,6 +25,7 @@ import { P, match } from 'ts-pattern';
import clsx from 'clsx';
import { type InputsOrOutputsWithRefs, type DataValueWithRefs, type ScalarDataValueWithRefs } from '../state/dataFlow';
import { getGlobalDataRef } from '../utils/globals';
import prettyBytes from 'pretty-bytes';

const styles = css`
.chat-message.user header em {
Expand Down Expand Up @@ -108,9 +111,17 @@ const scalarRenderers: {
return <pre className="pre-wrap">{truncated}</pre>;
},
'chat-message': ({ value, renderMarkdown }) => {
const parts = Array.isArray(value.value.message) ? value.value.message : [value.value.message];
const resolved = getGlobalDataRef(value.value.ref);

if (!resolved) {
return <div>Could not find data.</div>;
}

const { value: realValue } = resolved as ChatMessageDataValue;

const parts = Array.isArray(realValue.message) ? realValue.message : [realValue.message];

const message = value.value;
const message = realValue;

const messageContent = (
<div className="message-content">
Expand Down Expand Up @@ -276,6 +287,27 @@ const scalarRenderers: {
'graph-reference': ({ value }) => {
return <div>(Reference to graph &quot;{value.value.graphName}&quot;)</div>;
},
document: ({ value }) => {
const resolved = getGlobalDataRef(value.value.ref);
if (!resolved) {
return <div>Could not find data.</div>;
}

const {
value: { context, data, title, enableCitations, mediaType },
} = resolved as DocumentDataValue;

return (
<div>
<p>
{title ? `Document: ${title}` : 'Document'} ({mediaType})
</p>
{context && <p>{context}</p>}
{enableCitations && <p>(Citations enabled)</p>}
Size: {data.length > 0 ? prettyBytes(data.length) : '0 bytes'}
</div>
);
},
};
/* eslint-enable react-hooks/rules-of-hooks -- These are components (ish) */

Expand All @@ -301,6 +333,20 @@ const RenderChatMessagePart: FC<{ part: ChatMessageMessagePart; renderMarkdown?:
.with({ type: 'url' }, (part) => {
return <img className="chat-message-url-image" src={part.url} alt={part.url} />;
})
.with({ type: 'document' }, (part) => {
const { data, mediaType, context, title, enableCitations } = part;

return (
<div>
<p>
{title ? `Document: ${title}` : 'Document'} ({mediaType})
</p>
{context && <p>{context}</p>}
{enableCitations && <p>(Citations enabled)</p>}
Size: {data.length > 0 ? prettyBytes(data.length) : '0 bytes'}
</div>
);
})
.exhaustive();
};

Expand All @@ -317,9 +363,9 @@ export const RenderDataValue: FC<{
return <>undefined</>;
}

const keys = Object.keys(value?.value ?? {});
if (isArrayDataType(value.type)) {
const items = arrayizeDataValue(value as ScalarOrArrayDataValue);

return (
<div
css={multiOutput}
Expand Down
6 changes: 3 additions & 3 deletions packages/app/src/components/VisualNode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -764,9 +764,9 @@ const NormalVisualNodeContent: FC<{
/>
)}

<ErrorBoundary fallback={<div>Error rendering node output</div>}>
<NodeOutput node={node} />
</ErrorBoundary>
{/* <ErrorBoundary fallback={<div>Error rendering node output</div>}> */}
<NodeOutput node={node} />
{/* </ErrorBoundary> */}
<div className="node-resize">
<ResizeHandle onResizeStart={handleResizeStart} onResizeMove={handleResizeMove} />
</div>
Expand Down
126 changes: 119 additions & 7 deletions packages/app/src/hooks/useCurrentExecution.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import {
coerceTypeOptional,
getScalarTypeOf,
isArrayDataValue,
isScalarDataType,
isScalarDataValue,
arrayizeDataValue,
type ChatMessageMessagePart,
} from '@ironclad/rivet-core';
import { produce } from 'immer';
import { cloneDeep, mapValues } from 'lodash-es';
Expand All @@ -31,7 +35,7 @@ import { lastRecordingState } from '../state/execution';
import { trivetTestsRunningState } from '../state/trivet';
import { useLatest } from 'ahooks';
import { entries, keys } from '../../../core/src/utils/typeSafety';
import { match } from 'ts-pattern';
import { P, match } from 'ts-pattern';
import { previousDataPerNodeToKeepState } from '../state/settings';
import { nanoid } from 'nanoid';
import { setGlobalDataRef } from '../utils/globals';
Expand Down Expand Up @@ -91,6 +95,95 @@ function sanitizeDataValueForLength(value: DataValue | undefined) {
.otherwise((value): DataValue | undefined => value);
}

export function fixDataValueUint8Arrays(value: DataValue | undefined): DataValue | undefined {
if (!value) {
return undefined;
}

if (isArrayDataValue(value)) {
const arrayized = arrayizeDataValue(value);

const fixed = arrayized.map((val) => fixDataValueUint8Arrays(val));

return {
...value,
value: fixed.map((v) => v!.value),
} as DataValue;
}

const fix = (value: Uint8Array | object) =>
value instanceof Uint8Array ? value : Uint8Array.from(Object.values(value));

const fixed = match(value)
.with({ type: 'binary' }, (value): DataValue => {
return {
...value,
value: fix(value.value),
};
})
.with({ type: 'audio' }, (value): DataValue => {
return {
...value,
value: {
...value.value,
data: fix(value.value.data),
},
};
})
.with({ type: 'document' }, (value): DataValue => {
return {
...value,
value: {
...value.value,
data: fix(value.value.data),
},
};
})
.with({ type: 'image' }, (value): DataValue => {
return {
...value,
value: {
...value.value,
data: fix(value.value.data),
},
};
})
.with({ type: 'chat-message' }, (value): DataValue => {
if (Array.isArray(value.value.message)) {
return {
...value,
value: {
...value.value,
message: value.value.message.map((part) => fixChatMessagePartUint8Arrays(part)),
},
};
}

return {
...value,
value: {
...value.value,
message: fixChatMessagePartUint8Arrays(value.value.message),
},
};
})
.otherwise((value): DataValue => value);

return fixed;
}

function fixChatMessagePartUint8Arrays(part: ChatMessageMessagePart): ChatMessageMessagePart {
return match(part)
.with(P.string, (part) => part)
.with({ type: 'document' }, (part) => {
return {
...part,
data: Uint8Array.from(Object.values(part.data)),
};
})
.otherwise((part) => part);
}

function cloneNodeDataForHistory(data: Partial<NodeRunData>): Partial<NodeRunDataWithRefs> {
return {
...data,
Expand Down Expand Up @@ -120,13 +213,30 @@ function cloneNodeInputOrOutputDataForHistory(data: Inputs | Outputs | undefined

function convertToRef(value: DataValue): DataValueWithRefs {
const scalarType = getScalarTypeOf(value.type);
if (scalarType !== 'audio' && scalarType !== 'binary' && scalarType !== 'image') {
if (
scalarType !== 'audio' &&
scalarType !== 'binary' &&
scalarType !== 'image' &&
scalarType !== 'document' &&
scalarType !== 'chat-message'
) {
return cloneDeep(value) as DataValueWithRefs;
}

const refId = nanoid();
setGlobalDataRef(refId, value);
return { type: value.type, value: { ref: refId } } as DataValueWithRefs;
if (isScalarDataValue(value)) {
const refId = nanoid();
setGlobalDataRef(refId, value);
return { type: value.type, value: { ref: refId } } as DataValueWithRefs;
} else if (isArrayDataValue(value)) {
const mappedValues = value.value.map((val) => {
const asRef = convertToRef({ type: getScalarTypeOf(value.type), value: val } as DataValue);
return asRef.value;
});

return { type: value.type, value: mappedValues } as DataValueWithRefs;
} else {
return cloneDeep(value) as DataValueWithRefs;
}
}

export function useCurrentExecution() {
Expand Down Expand Up @@ -190,7 +300,8 @@ export function useCurrentExecution() {
const onNodeStart = ({ node, inputs, processId }: ProcessEvents['nodeStart']) => {
const sanitizedInputs: Inputs = {};
for (const [key, value] of entries(inputs)) {
sanitizedInputs[key] = sanitizeDataValueForLength(value) as DataValue;
const uint8ArrayFixed = fixDataValueUint8Arrays(value) as DataValue;
sanitizedInputs[key] = sanitizeDataValueForLength(uint8ArrayFixed) as DataValue;
}

setDataForNode(node.id, processId, {
Expand All @@ -204,7 +315,8 @@ export function useCurrentExecution() {
const onNodeFinish = ({ node, outputs, processId }: ProcessEvents['nodeFinish']) => {
const sanitizedOutputs: Outputs = {};
for (const [key, value] of entries(outputs)) {
sanitizedOutputs[key] = sanitizeDataValueForLength(value) as DataValue;
const uint8ArrayFixed = fixDataValueUint8Arrays(value) as DataValue;
sanitizedOutputs[key] = sanitizeDataValueForLength(uint8ArrayFixed) as DataValue;
}

setDataForNode(node.id, processId, {
Expand Down
4 changes: 3 additions & 1 deletion packages/app/src/state/dataFlow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ export type InputsOrOutputsWithRefs = {
export type DataValueWithRefs = {
[P in DataType]: {
type: P;
value: P extends 'binary' | 'audio' | 'image' ? { ref: string } : Extract<DataValue, { type: P }>['value'];
value: P extends 'binary' | 'audio' | 'image' | 'document' | 'chat-message'
? { ref: string }
: Extract<DataValue, { type: P }>['value'];
};
}[DataType];

Expand Down
Loading

0 comments on commit 74edaec

Please sign in to comment.