diff --git a/StabilityMatrix.Core/Models/Packages/Cogstudio.cs b/StabilityMatrix.Core/Models/Packages/Cogstudio.cs index 1e946cca..4d0b0c8c 100644 --- a/StabilityMatrix.Core/Models/Packages/Cogstudio.cs +++ b/StabilityMatrix.Core/Models/Packages/Cogstudio.cs @@ -65,15 +65,31 @@ public override async Task InstallPackage( progress?.Report(new ProgressReport(-1f, "Setting up Cogstudio files", isIndeterminate: true)); var gradioCompositeDemo = new FilePath(installLocation, "inference/gradio_composite_demo"); + var cogstudioFile = new FilePath(gradioCompositeDemo, "cogstudio.py"); gradioCompositeDemo.Directory?.Create(); await DownloadService - .DownloadToFileAsync( - cogstudioUrl, - new FilePath(gradioCompositeDemo, "cogstudio.py"), - cancellationToken: cancellationToken - ) + .DownloadToFileAsync(cogstudioUrl, cogstudioFile, cancellationToken: cancellationToken) .ConfigureAwait(false); + progress?.Report( + new ProgressReport( + -1f, + "Patching cogstudio.py to allow writing to the output folder", + isIndeterminate: true + ) + ); + var outputDir = new FilePath(installLocation, "output"); + if (Compat.IsWindows) + { + outputDir = outputDir.ToString().Replace("\\", "\\\\"); + } + var cogstudioContent = await cogstudioFile.ReadAllTextAsync(cancellationToken).ConfigureAwait(false); + cogstudioContent = cogstudioContent.Replace( + "demo.launch()", + $"demo.launch(allowed_paths=['{outputDir}'])" + ); + await cogstudioFile.WriteAllTextAsync(cogstudioContent, cancellationToken).ConfigureAwait(false); + progress?.Report(new ProgressReport(-1f, "Installing requirements", isIndeterminate: true)); var requirements = new FilePath(installLocation, "requirements.txt"); var pipArgs = new PipInstallArgs()