diff --git a/StabilityMatrix.Avalonia/DesignData/DesignData.cs b/StabilityMatrix.Avalonia/DesignData/DesignData.cs index 7874e75a..f609ecf8 100644 --- a/StabilityMatrix.Avalonia/DesignData/DesignData.cs +++ b/StabilityMatrix.Avalonia/DesignData/DesignData.cs @@ -1167,6 +1167,24 @@ public static CompletionList SampleCompletionList } }; + public static SafetensorMetadataViewModel SafetensorMetadataViewModel => + DialogFactory.Get(vm => + { + vm.Metadata = new SafetensorMetadata + { + TagFrequency = Enumerable + .Range(1, 100) + .Select(i => new SafetensorMetadata.Tag("tag" + i, i)) + .ToList(), + OtherMetadata = new List + { + new("Name1", "Value1"), + new("Name2", "Value2"), + new("Name3", "Value3"), + } + }; + }); + public static ModelMetadataEditorDialogViewModel MetadataEditorDialogViewModel => DialogFactory.Get(vm => { diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFileViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFileViewModel.cs index ce61361b..2f225f85 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFileViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointManager/CheckpointFileViewModel.cs @@ -307,6 +307,55 @@ private async Task RenameAsync() } } + [RelayCommand] + private async Task OpenSafetensorMetadataViewer() + { + if (!CheckpointFile.SafetensorMetadataParsed) + { + if ( + !settingsManager.IsLibraryDirSet + || new DirectoryPath(settingsManager.ModelsDirectory) is not { Exists: true } modelsDir + ) + { + return; + } + + try + { + var safetensorPath = CheckpointFile.GetFullPath(modelsDir); + + var metadata = await SafetensorMetadata.ParseAsync(safetensorPath); + + CheckpointFile.SafetensorMetadataParsed = true; + CheckpointFile.SafetensorMetadata = metadata; + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to parse safetensor metadata"); + return; + } + } + + if (!CheckpointFile.SafetensorMetadataParsed) + { + return; + } + + var vm = 
vmFactory.Get(vm => + { + vm.ModelName = CheckpointFile.DisplayModelName; + vm.Metadata = CheckpointFile.SafetensorMetadata; + }); + + var dialog = vm.GetDialog(); + dialog.MinDialogHeight = 800; + dialog.MinDialogWidth = 700; + dialog.CloseButtonText = "Close"; + dialog.DefaultButton = ContentDialogButton.Close; + + await dialog.ShowAsync(); + } + [RelayCommand] private async Task OpenMetadataEditor() { diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/SafetensorMetadataViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SafetensorMetadataViewModel.cs new file mode 100644 index 00000000..3047f5ac --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/SafetensorMetadataViewModel.cs @@ -0,0 +1,27 @@ +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; + +[View(typeof(SafetensorMetadataDialog))] +[ManagedService] +[RegisterSingleton] +public partial class SafetensorMetadataViewModel : ContentDialogViewModelBase +{ + [ObservableProperty] + private string? modelName; + + [ObservableProperty] + private SafetensorMetadata? 
metadata; + + [RelayCommand] + public void CopyTagToClipboard(string tag) + { + App.Clipboard?.SetTextAsync(tag); + } +} diff --git a/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml b/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml index 82b60b94..294de643 100644 --- a/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml +++ b/StabilityMatrix.Avalonia/Views/CheckpointsPage.axaml @@ -489,6 +489,11 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/SafetensorMetadataDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/SafetensorMetadataDialog.axaml.cs new file mode 100644 index 00000000..0318da8f --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/Dialogs/SafetensorMetadataDialog.axaml.cs @@ -0,0 +1,19 @@ +using Avalonia.Controls; +using Avalonia.Markup.Xaml; +using Injectio.Attributes; + +namespace StabilityMatrix.Avalonia.Views.Dialogs; + +[RegisterTransient] +public partial class SafetensorMetadataDialog : UserControl +{ + public SafetensorMetadataDialog() + { + InitializeComponent(); + } + + private void InitializeComponent() + { + AvaloniaXamlLoader.Load(this); + } +} diff --git a/StabilityMatrix.Core/Helper/Compat.cs b/StabilityMatrix.Core/Helper/Compat.cs index 9e4b352f..1a619bab 100644 --- a/StabilityMatrix.Core/Helper/Compat.cs +++ b/StabilityMatrix.Core/Helper/Compat.cs @@ -189,7 +189,11 @@ public static string GetExecutableName() var appImage = Environment.GetEnvironmentVariable("APPIMAGE"); if (string.IsNullOrEmpty(appImage)) { +#if DEBUG + return "DEBUG_NOT_RUNNING_IN_APPIMAGE"; +#else throw new Exception("Could not find APPIMAGE environment variable"); +#endif } return Path.GetFileName(appImage); } diff --git a/StabilityMatrix.Core/Models/Database/LocalModelFile.cs b/StabilityMatrix.Core/Models/Database/LocalModelFile.cs index bd648090..92ae245a 100644 --- a/StabilityMatrix.Core/Models/Database/LocalModelFile.cs +++ 
b/StabilityMatrix.Core/Models/Database/LocalModelFile.cs @@ -121,7 +121,10 @@ public override int GetHashCode() /// /// Blake3 hash of the file. /// - public string? HashBlake3 => ConnectedModelInfo?.Hashes.BLAKE3; + public string? HashBlake3 => ConnectedModelInfo?.Hashes?.BLAKE3; + + [BsonIgnore] + public bool IsSafetensorFile => Path.GetExtension(RelativePath) == ".safetensors"; [BsonIgnore] public string? PreviewImageFullPathGlobal => @@ -151,6 +154,12 @@ public override int GetHashCode() [MemberNotNullWhen(true, nameof(ConnectedModelInfo))] public bool HasCivitMetadata => HasConnectedModel && ConnectedModelInfo.ModelId != null; + [BsonIgnore] + public SafetensorMetadata? SafetensorMetadata { get; set; } + + [BsonIgnore] + public bool SafetensorMetadataParsed { get; set; } + public string GetFullPath(string rootModelDirectory) { return Path.Combine(rootModelDirectory, RelativePath); diff --git a/StabilityMatrix.Core/Models/SafetensorMetadata.cs b/StabilityMatrix.Core/Models/SafetensorMetadata.cs new file mode 100644 index 00000000..870f5687 --- /dev/null +++ b/StabilityMatrix.Core/Models/SafetensorMetadata.cs @@ -0,0 +1,225 @@ +using System.Buffers; +using System.Buffers.Binary; +using System.Text.Json; + +namespace StabilityMatrix.Core.Models; + +public record SafetensorMetadata +{ + // public string? NetworkModule { get; init; } + // public string? ModelSpecArchitecture { get; init; } + + public List? TagFrequency { get; init; } + + public required List OtherMetadata { get; init; } + + /// + /// Tries to parse the metadata from a SafeTensor file. + /// + /// Path to the SafeTensor file. + /// The parsed metadata. Can be null if the file does not contain metadata. + public static async Task ParseAsync(string safetensorPath) + { + using var stream = new FileStream(safetensorPath, FileMode.Open, FileAccess.Read, FileShare.Read); + return await ParseAsync(stream); + } + + /// + /// Tries to parse the metadata from a SafeTensor file. 
+ /// + /// Stream to the SafeTensor file. + /// The parsed metadata. Can be null if the file does not contain metadata. + public static async Task ParseAsync(Stream safetensorStream) + { + // 8 bytes unsigned little-endian 64-bit integer + // 1 byte start of JSON object '{' + Memory buffer = new byte[9]; + await safetensorStream.ReadExactlyAsync(buffer).ConfigureAwait(false); + var span = buffer.Span; + + const ulong MAX_ALLOWED_JSON_LENGTH = 100 * 1024 * 1024; // 100 MB + var jsonLength = BinaryPrimitives.ReadUInt64LittleEndian(span); + if (jsonLength > MAX_ALLOWED_JSON_LENGTH) + { + throw new InvalidDataException("JSON length exceeds the maximum allowed size."); + } + if (span[8] != '{') + { + throw new InvalidDataException("JSON does not start with '{'."); + } + + // Unfortunately Utf8JsonReader does not support reading from a stream directly. + // Usually the size of the entire JSON object is less than 500KB, + // using a pooled buffer should reduce the number of large allocations. + var jsonBytes = ArrayPool.Shared.Rent((int)jsonLength); + try + { + // Important: the length of the rented buffer can be larger than jsonLength + // and there can be additional junk data at the end. 
+ + // we already read {, so start from index 1 + jsonBytes[0] = (byte)'{'; + await safetensorStream + .ReadExactlyAsync(jsonBytes, 1, (int)(jsonLength - 1)) + .ConfigureAwait(false); + + // read the JSON with Utf8JsonReader, then only deserialize what we need + // saves us from allocating a bunch of strings then throwing them away + var reader = new Utf8JsonReader(jsonBytes.AsSpan(0, (int)jsonLength)); + + reader.Read(); + if (reader.TokenType != JsonTokenType.StartObject) + { + // expecting a JSON object + throw new InvalidDataException("JSON does not start with '{'."); + } + + while (reader.Read()) + { + // for each property in the object + if (reader.TokenType == JsonTokenType.EndObject) + { + // end of the object, no "__metadata__" found + // return null to indicate that we successfully read the JSON + // but it does not contain metadata + return null; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + // expecting a property name + throw new InvalidDataException( + $"Invalid metadata JSON, expected property name but got {reader.TokenType}." 
+ ); + } + + if (reader.ValueTextEquals("__metadata__")) + { + if (JsonSerializer.Deserialize>(ref reader) is { } dict) + { + return FromDictionary(dict); + } + + // got null from Deserialize + throw new InvalidDataException("Failed to deserialize metadata."); + } + else + { + // skip the property value + reader.Skip(); + } + } + // should not reach here, json is malformed + throw new InvalidDataException("Invalid metadata JSON."); + } + finally + { + ArrayPool.Shared.Return(jsonBytes); + } + } + + private static readonly HashSet MetadataKeys = + [ + // "ss_network_module", + // "modelspec.architecture", + "ss_tag_frequency", + ]; + + internal static SafetensorMetadata FromDictionary(Dictionary metadataDict) + { + // equivalent to the following code, rewritten manually for performance + // otherMetadata = metadataDict + // .Where(kv => !MetadataKeys.Contains(kv.Key)) + // .Select(kv => new Metadata(kv.Key, kv.Value)) + // .OrderBy(x => x.Name) + // .ToList(); + var otherMetadata = new List(metadataDict.Count); + foreach (var kv in metadataDict) + { + if (MetadataKeys.Contains(kv.Key)) + { + continue; + } + + otherMetadata.Add(new Metadata(kv.Key, kv.Value)); + } + otherMetadata.Sort((x, y) => string.Compare(x.Name, y.Name, StringComparison.Ordinal)); + + var metadata = new SafetensorMetadata + { + // NetworkModule = metadataDict.GetValueOrDefault("ss_network_module"), + // ModelSpecArchitecture = metadataDict.GetValueOrDefault("modelspec.architecture"), + OtherMetadata = otherMetadata + }; + + if (metadataDict.TryGetValue("ss_tag_frequency", out var tagFrequencyJson)) + { + try + { + // ss_tag_frequency example: + // { "some_name": {"tag1": 5, "tag2": 10}, "another_name": {"tag1": 3, "tag3": 1} } + // we flatten the dictionary of dictionaries into a single dictionary + + var tagFrequencyDict = new Dictionary(); + + var doc = JsonDocument.Parse(tagFrequencyJson); + var root = doc.RootElement; + if (root.ValueKind == JsonValueKind.Object) + { + foreach (var property 
in root.EnumerateObject()) + { + var tags = property.Value; + if (tags.ValueKind != JsonValueKind.Object) + { + continue; + } + + foreach (var tagProperty in tags.EnumerateObject()) + { + var tagName = tagProperty.Name; + + if ( + string.IsNullOrEmpty(tagName) + || tagProperty.Value.ValueKind != JsonValueKind.Number + ) + { + continue; + } + + var count = tagProperty.Value.GetInt32(); + if (!tagFrequencyDict.TryAdd(tagName, count)) + { + // tag already exists, increment the count + tagFrequencyDict[tagName] += count; + } + } + } + } + + // equivalent to the following code, rewritten manually for performance + // tagFrequency = tagFrequencyDict + // .Select(kv => new Tag(kv.Key, kv.Value)) + // .OrderByDescending(x => x.Frequency) + // .ToList(); + var tagFrequency = new List(tagFrequencyDict.Count); + foreach (var kv in tagFrequencyDict) + { + tagFrequency.Add(new Tag(kv.Key, kv.Value)); + } + tagFrequency.Sort((x, y) => y.Frequency.CompareTo(x.Frequency)); + + metadata = metadata with { TagFrequency = tagFrequency }; + } + catch (Exception) + { + // ignore: on any failure while reading ss_tag_frequency, TagFrequency stays null + } + } + + return metadata; + } + + public readonly record struct Tag(string Name, int Frequency); + + public readonly record struct Metadata(string Name, string Value); +} diff --git a/StabilityMatrix.Core/Services/ModelIndexService.cs b/StabilityMatrix.Core/Services/ModelIndexService.cs index c87372a6..53943a5c 100644 --- a/StabilityMatrix.Core/Services/ModelIndexService.cs +++ b/StabilityMatrix.Core/Services/ModelIndexService.cs @@ -28,6 +28,7 @@ public partial class ModelIndexService : IModelIndexService private readonly ISettingsManager settingsManager; private readonly ILiteDbContext liteDbContext; private readonly ModelFinder modelFinder; + private readonly SemaphoreSlim safetensorMetadataParseLock = new(1, 1); private DateTimeOffset lastUpdateCheck = DateTimeOffset.MinValue; @@ -543,6 +544,86 @@ await liteDbContext ); EventManager.Instance.OnModelIndexChanged(); + + Task.Run(LoadSafetensorMetadataAsync) 
+ .SafeFireAndForget(ex => + { + logger.LogError(ex, "Error loading safetensor metadata"); + }); + } + + private async Task LoadSafetensorMetadataAsync() + { + if (!settingsManager.IsLibraryDirSet) + { + logger.LogTrace("Safetensor metadata loading skipped, library directory not set"); + return; + } + + if (new DirectoryPath(settingsManager.ModelsDirectory) is not { Exists: true } modelsDir) + { + logger.LogTrace("Safetensor metadata loading skipped, model directory does not exist"); + return; + } + + await safetensorMetadataParseLock.WaitAsync().ConfigureAwait(false); + try + { + var stopwatch = Stopwatch.StartNew(); + var readSuccess = 0; + var readFail = 0; + logger.LogInformation("Loading safetensor metadata..."); + + var models = ModelIndex + .Values.SelectMany(x => x) + .Where(m => !m.SafetensorMetadataParsed && m.RelativePath.EndsWith(".safetensors")); + + await Parallel + .ForEachAsync( + models, + new ParallelOptions + { + MaxDegreeOfParallelism = Math.Max(1, Math.Min(Environment.ProcessorCount / 2, 6)), + TaskScheduler = TaskScheduler.Default, + }, + async (model, token) => + { + if (model.SafetensorMetadataParsed) + return; + + if (!model.RelativePath.EndsWith(".safetensors")) + return; + + try + { + var safetensorPath = model.GetFullPath(modelsDir); + var metadata = await SafetensorMetadata + .ParseAsync(safetensorPath) + .ConfigureAwait(false); + model.SafetensorMetadata = metadata; + model.SafetensorMetadataParsed = true; + + Interlocked.Increment(ref readSuccess); + } + catch + { + Interlocked.Increment(ref readFail); + } + } + ) + .ConfigureAwait(false); + + logger.LogInformation( + "Loaded safetensor metadata for {Success} models, failed to load for {Fail} models in {Time:F2}ms", + readSuccess, + readFail, + stopwatch.Elapsed.TotalMilliseconds + ); + } + finally + { + safetensorMetadataParseLock.Release(); + } } /// @@ -672,7 +753,7 @@ private static HashSet CollectModelHashes(IEnumerable mo var hashes = new HashSet(); foreach (var model in models) 
{ - if (model.ConnectedModelInfo?.Hashes.BLAKE3 is { } hashBlake3) + if (model.ConnectedModelInfo?.Hashes?.BLAKE3 is { } hashBlake3) { hashes.Add(hashBlake3); } diff --git a/StabilityMatrix.Tests/Models/SafetensorMetadataTests.cs b/StabilityMatrix.Tests/Models/SafetensorMetadataTests.cs new file mode 100644 index 00000000..780ff015 --- /dev/null +++ b/StabilityMatrix.Tests/Models/SafetensorMetadataTests.cs @@ -0,0 +1,41 @@ +using System.Buffers.Binary; +using System.Text; +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Tests.Models; + +[TestClass] +public class SafetensorMetadataTests +{ + [TestMethod] + public async Task TestParseStreamAsync() + { + const string SOURCE_JSON = """ +{ +"anything":[1,2,3,4,"",{ "a": 1, "b": 2, "c": 3 }], +"__metadata__":{"ss_network_module":"some network module","modelspec.architecture":"some architecture", + "ss_tag_frequency":"{\"aaa\":{\"tag1\":59,\"tag2\":2},\"bbb\":{\"tag1\":4,\"tag3\":1}}" }, +"someotherdata":{ "a": 1, "b": 2, "c": 3 } +} +"""; + + var stream = new MemoryStream(); + Span buffer = stackalloc byte[8]; + BinaryPrimitives.WriteUInt64LittleEndian(buffer, (ulong)SOURCE_JSON.Length); + stream.Write(buffer); + stream.Write(Encoding.UTF8.GetBytes(SOURCE_JSON)); + stream.Position = 0; + + var metadata = await SafetensorMetadata.ParseAsync(stream); + + // Assert.AreEqual("some network module", metadata.NetworkModule); + // Assert.AreEqual("some architecture", metadata.ModelSpecArchitecture); + + Assert.IsNotNull(metadata); + Assert.IsNotNull(metadata.TagFrequency); + CollectionAssert.AreEqual( + new List { new("tag1", 63), new("tag2", 2), new("tag3", 1) }, + metadata.TagFrequency + ); + } +}