Skip to content

Commit

Permalink
Auto-install the imported packages (#902)
Browse files Browse the repository at this point in the history
* Add autoInstall option to the kernel

* Run pyodide.loadPackagesFromImports() at init and file writing

* Send the auto-installed packages to the editor and update the requirements tab content

* Update the auto install event emits a promise to resolve the installed package list

* Update to call loadPackage() only when needed

* Improve the toast notification

* Fix

* Fix

* Update the editor content

* Fix

* Implement module auto-load at each run and split the autoInstall option into moduleAutoLoadOnRun and moduleAutoLoadOnSave

* Revert "Implement module auto-load at each run and split the autoInstall option into moduleAutoLoadOnRun and moduleAutoLoadOnSave"

This reverts commit 61cfd1c.

* Rename auto-install to module-auto-load

* Update findImports impl

* Fix to await the auto-load promise in the script runner

* Fix

* Rename messages

* Fix

* Apply formatter

* Add comment

* Refactoring

* Refactoring

* Fix inter-window messaging

* Fix

* Fix

* Fix <Editor /> to handle addRequirements() in an imperative way
  • Loading branch information
whitphx authored Jun 5, 2024
1 parent 5a5e5e1 commit 0c6760d
Show file tree
Hide file tree
Showing 14 changed files with 607 additions and 256 deletions.
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@
import React from "react";
import { StliteKernel } from "@stlite/kernel";
import { toast } from "react-toastify";
import type { StliteKernel, StliteKernelOptions } from "@stlite/kernel";
import { toast, ToastPromiseParams } from "react-toastify";
import ErrorToastContent from "./ErrorToastContent";

type ToastPromiseParams<TData> = Parameters<typeof toast.promise<TData>>;
type ToastPromiseMessages<TData> = Partial<
Record<keyof ToastPromiseParams<TData>[1], string>
>;
type ToastPromiseReturnType = ReturnType<typeof toast.promise>;
function stliteStyledPromiseToast<TData>(
promise: ToastPromiseParams<TData>[0],
messages: ToastPromiseMessages<TData>
): ToastPromiseReturnType {
function stliteStyledPromiseToast<
TData = unknown,
TError extends Error | undefined = undefined,
TPending = unknown
>(
promise: Promise<TData>,
messages: ToastPromiseParams<TData, TError, TPending>
): ReturnType<typeof toast.promise> {
const errorMessage = messages.error;
return toast.promise<TData, Error>(
return toast.promise<TData, TError, TPending>(
promise,
{
pending: messages.pending,
success: messages.success,
error: errorMessage && {
render({ data }) {
return data ? (
<ErrorToastContent message={errorMessage} error={data} />
) : (
messages.error
);
},
autoClose: false,
closeOnClick: false,
},
error:
typeof errorMessage === "string"
? {
render({ data }) {
return data ? (
<ErrorToastContent message={errorMessage} error={data} />
) : (
<>messages.error</>
);
},
autoClose: false,
closeOnClick: false,
}
: errorMessage,
},
{
hideProgressBar: true,
Expand All @@ -37,8 +39,32 @@ function stliteStyledPromiseToast<TData>(
);
}

export interface StliteKernelWithToastOptions {
onModuleAutoLoad?: StliteKernelOptions["onModuleAutoLoad"];
}
export class StliteKernelWithToast {
constructor(private kernel: StliteKernel) {}
constructor(
private kernel: StliteKernel,
options?: StliteKernelWithToastOptions
) {
kernel.onModuleAutoLoad = (packagesToLoad, installPromise) => {
if (options?.onModuleAutoLoad) {
options.onModuleAutoLoad(packagesToLoad, installPromise);
}

stliteStyledPromiseToast(installPromise, {
success: {
render({ data }) {
return `Auto-loaded${
data ? ": " + data.map((pkg) => pkg.name).join(", ") : " packages"
}`;
},
},
error: "Failed to auto-load packages",
pending: "Auto-loading packages",
});
};
}

public writeFile(...args: Parameters<StliteKernel["writeFile"]>) {
return stliteStyledPromiseToast<void>(this.kernel.writeFile(...args), {
Expand Down
9 changes: 6 additions & 3 deletions packages/desktop/electron/worker.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { parentPort } from "node:worker_threads";
import { startWorkerEnv } from "@stlite/kernel/src/worker-runtime";
import {
startWorkerEnv,
type PostMessageFn,
} from "@stlite/kernel/src/worker-runtime";
import { loadNodefsMountpoints } from "./worker-options";

function postMessage(value: any) {
const postMessage: PostMessageFn = (value) => {
console.debug("[worker thread] postMessage from worker", value);
parentPort?.postMessage(value);
}
};

const handleMessage = startWorkerEnv(
process.env.PYODIDE_URL as string,
Expand Down
46 changes: 39 additions & 7 deletions packages/kernel/src/kernel.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Ref: https://github.com/jupyterlite/jupyterlite/blob/f2ecc9cf7189cb19722bec2f0fc7ff5dfd233d47/packages/pyolite-kernel/src/kernel.ts

import type { PackageData } from "pyodide";
import { PromiseDelegate } from "@stlite/common";

import type { IHostConfigResponse } from "@streamlit/lib/src/hostComm/types";
Expand All @@ -21,6 +22,7 @@ import type {
StliteWorker,
WorkerInitialData,
StreamlitConfig,
ModuleAutoLoadMessage,
} from "./types";
import { assertStreamlitConfig } from "./types";

Expand Down Expand Up @@ -117,6 +119,13 @@ export interface StliteKernelOptions {

idbfsMountpoints?: WorkerInitialData["idbfsMountpoints"];

moduleAutoLoad?: WorkerInitialData["moduleAutoLoad"];

onModuleAutoLoad?: (
packagesToLoad: string[],
installPromise: Promise<PackageData[]>
) => void;

onProgress?: (message: string) => void;

onLoad?: () => void;
Expand All @@ -143,11 +152,10 @@ export class StliteKernel {

public readonly hostConfigResponse: IHostConfigResponse; // Will be passed to ConnectionManager to call `onHostConfigResp` from it.

private onProgress: StliteKernelOptions["onProgress"];

private onLoad: StliteKernelOptions["onLoad"];

private onError: StliteKernelOptions["onError"];
public onProgress: StliteKernelOptions["onProgress"];
public onLoad: StliteKernelOptions["onLoad"];
public onError: StliteKernelOptions["onError"];
public onModuleAutoLoad: StliteKernelOptions["onModuleAutoLoad"];

constructor(options: StliteKernelOptions) {
this.basePath = (options.basePath ?? window.location.pathname)
Expand All @@ -157,6 +165,7 @@ export class StliteKernel {
this.onProgress = options.onProgress;
this.onLoad = options.onLoad;
this.onError = options.onError;
this.onModuleAutoLoad = options.onModuleAutoLoad;

if (options.worker) {
this._worker = options.worker;
Expand All @@ -168,7 +177,8 @@ export class StliteKernel {
}

this._worker.onmessage = (e) => {
this._processWorkerMessage(e.data);
const messagePort: MessagePort | undefined = e.ports[0];
this._processWorkerMessage(e.data, messagePort);
};

let wheels: WorkerInitialData["wheels"] = undefined;
Expand Down Expand Up @@ -209,6 +219,7 @@ export class StliteKernel {
options.mountedSitePackagesSnapshotFilePath,
streamlitConfig: options.streamlitConfig,
idbfsMountpoints: options.idbfsMountpoints,
moduleAutoLoad: options.moduleAutoLoad ?? false,
};
}

Expand Down Expand Up @@ -337,7 +348,7 @@ export class StliteKernel {
*
* @param msg The worker message to process.
*/
private _processWorkerMessage(msg: OutMessage): void {
private _processWorkerMessage(msg: OutMessage, port?: MessagePort): void {
switch (msg.type) {
case "event:start": {
this._worker.postMessage({
Expand All @@ -364,6 +375,27 @@ export class StliteKernel {
this.handleWebSocketMessage && this.handleWebSocketMessage(payload);
break;
}
case "event:moduleAutoLoad": {
if (port == null) {
throw new Error("Port is required for moduleAutoLoad event");
}
this.onModuleAutoLoad &&
this.onModuleAutoLoad(
msg.data.packagesToLoad,
new Promise((resolve, reject) => {
port.onmessage = (e) => {
const msg: ModuleAutoLoadMessage = e.data;
if (msg.type === "moduleAutoLoad:success") {
resolve(msg.data.loadedPackages);
} else {
reject(msg.error);
}
port.close();
};
})
);
break;
}
}
}

Expand Down
71 changes: 71 additions & 0 deletions packages/kernel/src/module-auto-load.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import type { PackageData, PyodideInterface } from "pyodide";

Check warning on line 1 in packages/kernel/src/module-auto-load.ts

View workflow job for this annotation

GitHub Actions / test-kernel

'PackageData' is defined but never used
import type { ModuleAutoLoadMessage } from "./types";
import type { PostMessageFn } from "./worker-runtime";

function findImports(pyodide: PyodideInterface, source: string): string[] {
return pyodide.pyodide_py.ffi._pyodide._base
.find_imports(source)
.toJs() as string[];
}

export async function tryModuleAutoLoad(
pyodide: PyodideInterface,
postMessage: PostMessageFn,
sources: string[]
): Promise<void> {
// Ref: `pyodide.loadPackagesFromImports` (https://github.com/pyodide/pyodide/blob/0.26.0/src/js/api.ts#L191)

const importsArr = sources.map((source) => findImports(pyodide, source));
const imports = Array.from(new Set(importsArr.flat()));

const notFoundImports = imports.filter(
(name) =>
!pyodide.runPython(`__import__('importlib').util.find_spec('${name}')`)
);

const packagesToLoad = notFoundImports
.map((name) =>
(
pyodide as unknown as {
_api: { _import_name_to_package_name: Map<string, string> };
}
)._api._import_name_to_package_name.get(name)
)
.filter((name) => name) as string[];

if (packagesToLoad.length === 0) {
return;
}

const channel = new MessageChannel();

postMessage(
{
type: "event:moduleAutoLoad",
data: {
packagesToLoad,
},
},
channel.port2
);

try {
const loadedPackages = await pyodide.loadPackage(packagesToLoad);

channel.port1.postMessage({
type: "moduleAutoLoad:success",
data: {
loadedPackages,
},
} as ModuleAutoLoadMessage);
channel.port1.close();
return;
} catch (error) {
channel.port1.postMessage({
type: "moduleAutoLoad:error",
error: error as Error,
} as ModuleAutoLoadMessage);
channel.port1.close();
throw error;
}
}
27 changes: 25 additions & 2 deletions packages/kernel/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { PyodideInterface } from "pyodide";
import type { PyodideInterface, PackageData } from "pyodide";

export type PyodideConvertiblePrimitive =
| string
Expand Down Expand Up @@ -57,6 +57,7 @@ export interface WorkerInitialData {
streamlitConfig?: StreamlitConfig;
idbfsMountpoints?: string[];
nodefsMountpoints?: Record<string, string>;
moduleAutoLoad: boolean;
}

/**
Expand Down Expand Up @@ -161,12 +162,34 @@ export interface OutMessageWebSocketBack extends OutMessageBase {
payload: Uint8Array | string;
};
}
export interface OutMessageModuleAutoLoadEvent extends OutMessageBase {
type: "event:moduleAutoLoad";
data: {
packagesToLoad: string[];
};
}
export type OutMessage =
| OutMessageStartEvent
| OutMessageProgressEvent
| OutMessageErrorEvent
| OutMessageLoadedEvent
| OutMessageWebSocketBack;
| OutMessageWebSocketBack
| OutMessageModuleAutoLoadEvent;

export interface ModuleAutoLoadMessageBase {
type: string;
}
export interface ModuleAutoLoadSuccess extends ModuleAutoLoadMessageBase {
type: "moduleAutoLoad:success";
data: {
loadedPackages: PackageData[];
};
}
export interface ModuleAutoLoadError extends ModuleAutoLoadMessageBase {
type: "moduleAutoLoad:error";
error: Error;
}
export type ModuleAutoLoadMessage = ModuleAutoLoadSuccess | ModuleAutoLoadError;

/**
* Reply message to InMessage
Expand Down
Loading

0 comments on commit 0c6760d

Please sign in to comment.