Skip to content

Commit

Permalink
Add support for netstandard2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
georg-jung committed Apr 26, 2024
1 parent 201c23b commit 1e2f8fa
Show file tree
Hide file tree
Showing 23 changed files with 949 additions and 33 deletions.
3 changes: 3 additions & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
<PackageVersion Include="PolySharp" Version="1.14.1" />
<PackageVersion Include="Shouldly" Version="4.2.1" />
<PackageVersion Include="SimpleSIMD" Version="4.6.0" />
<PackageVersion Include="System.Memory" Version="4.5.5" />
<PackageVersion Include="System.Text.Json" Version="8.0.2" />
<PackageVersion Include="System.Threading.Channels" Version="8.0.0" />
<PackageVersion Include="Verify.Xunit" Version="24.1.0" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.5.8" />
<PackageVersion Include="xunit" Version="2.7.1" />
Expand Down
43 changes: 43 additions & 0 deletions src/FastBertTokenizer.Tests/AsyncBatchEnumeratorVsHuggingface.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ private async Task CompareSimpleWikipediaCorpusAsIsImpl(BertTokenizer uut, Dicti
{
continue;
}
#if NETFRAMEWORK
else if (DidWeSkipSomethingWhereHuggingFaceEmittedUnk(huggF.InputIds.Span, currentInputIds.Span))
{
continue;
}
#endif

var needsToMatchUpToIdx = currentInputIds.Length - 1;

Expand Down Expand Up @@ -306,4 +312,41 @@ private bool DidWeEmitOneUnkWhereHuggingFaceJustSkippedSomething(ReadOnlySpan<lo
(iHF == huggF.Length && iOurs == ours.Length)
|| (iHF + cases == iOurs && iOurs == ours.Length);
}

// Did Hugging Face emitt one ore more [UNK] where we just skipped something?
// This is relevant for .netframework, probably due to it's outdated unicode data.
private bool DidWeSkipSomethingWhereHuggingFaceEmittedUnk(ReadOnlySpan<long> huggF, ReadOnlySpan<long> ours)
{
var skippedHfUnk = 0;
var (iHF, iOurs) = (0, 0);
while (iHF < huggF.Length && iOurs < ours.Length)
{
if (huggF[iHF] == ours[iOurs])
{
iHF++;
iOurs++;
continue;
}

// [SEP] == 102
// Hugging face will be at the end earlier if it emitted [UNK] where we skipped something.
if (skippedHfUnk > 0 && huggF[iHF] == 102)
{
iOurs++;
break;
}

// [UNK] == 100
// Skip [UNK]s in Hugging Face's result
while (iHF < huggF.Length && huggF[iHF] == 100)
{
skippedHfUnk++;
iHF++;
}
}

return
(iHF == huggF.Length && iOurs == ours.Length)
|| (iHF == huggF.Length && iOurs == ours.Length - skippedHfUnk);
}
}
78 changes: 78 additions & 0 deletions src/FastBertTokenizer.Tests/Backports.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) Georg Jung. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Threading.Channels;

namespace FastBertTokenizer;

#if NETFRAMEWORK
internal static class Backports
{
public static void Deconstruct<TKey, TValue>(this KeyValuePair<TKey, TValue> kvp, out TKey key, out TValue value)
{
key = kvp.Key;
value = kvp.Value;
}
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.MaintainabilityRules", "SA1402:File may only contain a single type", Justification = "Keep the backports together")]
internal static class ChannelReaderExtensions
{
public static IAsyncEnumerable<T> AsAsyncEnumerable<T>(this ChannelReader<T> reader)
{
if (reader == null)
{
throw new ArgumentNullException(nameof(reader));
}

return new ChannelReaderAsyncEnumerable<T>(reader);
}

// This is probably rather naive and may have bugs but it is for testing only and seems sufficient at this point.
private class ChannelReaderAsyncEnumerable<T> : IAsyncEnumerable<T>
{
private readonly ChannelReader<T> _reader;

public ChannelReaderAsyncEnumerable(ChannelReader<T> reader)
{
_reader = reader;
}

public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new ChannelReaderAsyncEnumerator(_reader, cancellationToken);
}

private class ChannelReaderAsyncEnumerator : IAsyncEnumerator<T>
{
private readonly ChannelReader<T> _reader;
private readonly CancellationToken _cancellationToken;

public ChannelReaderAsyncEnumerator(ChannelReader<T> reader, CancellationToken cancellationToken)
{
_reader = reader;
_cancellationToken = cancellationToken;
Current = default!;
}

public T Current { get; private set; }

public async ValueTask<bool> MoveNextAsync()
{
if (_reader.Completion.IsCompleted)
{
return false;
}

Current = await _reader.ReadAsync();
return true;
}

public ValueTask DisposeAsync()
{
return default;
}
}
}
}
#endif
14 changes: 12 additions & 2 deletions src/FastBertTokenizer.Tests/CompareDifferentEncodeFlavors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ public async Task CompareFlavors(Dictionary<int, string> articles)

channel.Writer.TryComplete().ShouldBeTrue();

await foreach (var batch in _uut.CreateAsyncBatchEnumerator(channel.Reader, 512, 100, stride: 0))
#if NETFRAMEWORK
var asyncEnum = channel.Reader.AsAsyncEnumerable();
#else
var asyncEnum = channel.Reader;
#endif
await foreach (var batch in _uut.CreateAsyncBatchEnumerator(asyncEnum, 512, 100, stride: 0))
{
for (var idx = 0; idx < batch.OutputCorrelation.Length; idx++)
{
Expand Down Expand Up @@ -140,7 +145,12 @@ public async Task CompareFlavors(Dictionary<int, string> articles)

channel2.Writer.TryComplete().ShouldBeTrue();

await foreach (var batch in _uut.CreateAsyncBatchEnumerator(channel2.Reader, 512, 100, stride: 27))
#if NETFRAMEWORK
var asyncEnum2 = channel.Reader.AsAsyncEnumerable();
#else
var asyncEnum2 = channel.Reader;
#endif
await foreach (var batch in _uut.CreateAsyncBatchEnumerator(asyncEnum2, 512, 100, stride: 27))
{
for (var idx = 0; idx < batch.OutputCorrelation.Length; idx++)
{
Expand Down
9 changes: 9 additions & 0 deletions src/FastBertTokenizer.Tests/CompareToHuggingfaceTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ private void CompareImpl(int id, string content)
return;
}

#if NETFRAMEWORK
if (id == 19308)
{
// 19308 "Mali" contains surrogate characters. .NET Framework unicode-categorizes them as "other, not assigned" and thus
// They get removed instead emitting an [UNK]. The difference is probably due to newer a unicode version used by huggingface & modern .NET.
return;
}
#endif

var huggF = RustTokenizer.TokenizeAndGetIds(content, 512);
var ours = _uut.Encode(content, 512, 512);
try
Expand Down
2 changes: 1 addition & 1 deletion src/FastBertTokenizer.Tests/Decode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void DecodeStartingFromSuffix()
var decoded = _uut.Decode(loremIpsum);
decoded.ShouldStartWith("[CLS] lorem ipsum");

long[] startsWithSuffix = loremIpsum[2..];
long[] startsWithSuffix = loremIpsum.AsSpan(2).ToArray();
decoded = _uut.Decode(startsWithSuffix);
decoded.ShouldStartWith("m ipsum");
}
Expand Down
10 changes: 9 additions & 1 deletion src/FastBertTokenizer.Tests/FastBertTokenizer.Tests.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net7.0;net8.0</TargetFrameworks>
<TargetFrameworks>net7.0;net8.0;net48</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>

Expand All @@ -17,6 +17,10 @@
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="PolySharp">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Shouldly" />
<PackageReference Include="Verify.Xunit" />
<PackageReference Include="xunit" />
Expand All @@ -31,6 +35,10 @@
<PackageReference Include="Xunit.SkippableFact" />
</ItemGroup>

<ItemGroup Condition=" '$(TargetFramework)' == 'net48' ">
<PackageReference Include="System.Threading.Channels" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\FastBertTokenizer\FastBertTokenizer.csproj" />
<ProjectReference Include="..\HuggingfaceTokenizer\RustLibWrapper\RustLibWrapper.csproj" />
Expand Down
2 changes: 2 additions & 0 deletions src/FastBertTokenizer.Tests/LoadTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,15 @@ public async Task LoadTokenizerFromInvalidVocabTxtAsync()
await tokenizer.LoadVocabularyAsync("data/invalid/minimal.txt", true);
}

#if !NETFRAMEWORK
[Theory]
[InlineData("bert-base-uncased")]
public async Task LoadFromHuggingFace(string huggingFaceRepo)
{
var tokenizer = new BertTokenizer();
await tokenizer.LoadFromHuggingFaceAsync(huggingFaceRepo);
}
#endif

[Fact]
public async Task PreventLoadAfterLoad()
Expand Down
2 changes: 2 additions & 0 deletions src/FastBertTokenizer.Tests/RestBaaiBgeTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

namespace FastBertTokenizer.Tests;

#if !NETFRAMEWORK
public class RestBaaiBgeTokenizer
{
private readonly string _requestUri;
Expand Down Expand Up @@ -50,3 +51,4 @@ private record class RestModel(
{
}
}
#endif
6 changes: 2 additions & 4 deletions src/FastBertTokenizer.Tests/WikipediaSimpleData.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
// Copyright (c) Georg Jung. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.IO.Compression;
using System.Text.Json;

namespace FastBertTokenizer.Tests;

public static class WikipediaSimpleData
{
private const string Path = "data/wiki-simple.json.br";
private const string Path = "data/wiki-simple.json";
private static readonly Lazy<List<object[]>> _articles = new(GetArticlesImpl);
private static readonly Lazy<Dictionary<int, string>> _articlesDict = new(GetArticlesDictImpl);

Expand All @@ -24,7 +23,6 @@ private static List<object[]> GetArticlesImpl()
private static Dictionary<int, string> GetArticlesDictImpl()
{
using var fs = File.OpenRead(Path);
using var uncompress = new BrotliStream(fs, CompressionMode.Decompress);
return JsonSerializer.Deserialize<Dictionary<int, string>>(uncompress)!;
return JsonSerializer.Deserialize<Dictionary<int, string>>(fs)!;
}
}
Loading

0 comments on commit 1e2f8fa

Please sign in to comment.