Different fix if single word tokens > maximumTokens to match huggingface
georg-jung committed Nov 30, 2023
1 parent ca7856e commit 4a6268a
Showing 1 changed file with 21 additions and 16 deletions.
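Context for the change: when a single word's WordPiece expansion alone exceeds maximumTokens, FastBertTokenizer previously replaced the whole word with a single [UNK] and continued; with this commit it instead cuts the expansion off where the buffer fills up (reserving the last slot for the final [SEP]) and records how far into the word it got, matching the output of huggingface tokenizers. A minimal, self-contained sketch of the resulting output shape (toy expansion and hypothetical names, not FastBertTokenizer's API):

using System;

static class TruncationSketch
{
    // Toy WordPiece: expand "aaaa..." to [a, ##a, ##a, ...], one piece per character.
    private static string[] ExpandWord(string word)
    {
        var pieces = new string[word.Length];
        for (var i = 0; i < word.Length; i++)
        {
            pieces[i] = i == 0 ? word[i].ToString() : "##" + word[i];
        }

        return pieces;
    }

    private static void Main()
    {
        const int maximumTokens = 8;
        var pieces = ExpandWord(new string('a', 20)); // 20 pieces, far more than fit

        // [CLS] and [SEP] occupy two slots, so maximumTokens - 2 pieces remain.
        // Old behavior: [CLS] [UNK] [SEP]. New behavior (matching huggingface):
        var kept = Math.Min(pieces.Length, maximumTokens - 2);
        Console.WriteLine("[CLS] " + string.Join(" ", pieces, 0, kept) + " [SEP]");
        // -> [CLS] a ##a ##a ##a ##a ##a [SEP]   (exactly maximumTokens tokens)
    }
}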
src/FastBertTokenizer/BertTokenizer.cs (21 additions, 16 deletions)
@@ -179,23 +179,25 @@ internal TokenizedRange<TKey> TokenizeBatchElement<TKey>(
 foreach (var pivot in new PreTokenizingEnumerator(input, _lowercaseInput, _normalization, inputOffset))
 {
     lastTokenizedWordStartIndex = pivot.SegmentStartIndex;
-    var added = TokenizeSubword(pivot.Segment, inputIds.Slice(inputIdCnt, inputIds.Length - inputIdCnt));
+    var offset = 0;
+    var added = TokenizeSubword(pivot.Segment, inputIds.Slice(inputIdCnt, inputIds.Length - inputIdCnt), ref offset);

     // subword was/needs to be cut off because it was too long
     if (inputIdCnt + added + 1 > maximumTokens)
     {
         if (inputIdCnt == 0 || (inputIdCnt == 1 && emitClsToken))
         {
             // The word didn't fit even though this is the first word we're trying to tokenize.
-            // We'll need to emit [UNK] instead of the word because this would otherwise lead to
-            // an endless loop if we try to tokenize the "remainder" of this input. This is does
-            // make sense too because this is probably some strange sequence of characters or we
-            // chose a very small maximumTokens value.
+            // We'll just pretend the word ended here at the point where it didn't fit anymore.
+            // If the remainder of this input is also tokenized this will incorrectly assume that
+            // a new word starts at the beginning of the remainder.
             // This is quite an edge case because it just happens if one single word tokenizes to
-            // more tokens then maximumTokens. This is very unlikely to happen in practice.
-            inputIds[inputIdCnt] = _unk.Id;
-            inputIdCnt++;
-            continue;
+            // more tokens than maximumTokens. This is very unlikely to happen in practice. We
+            // probably tokenize some strange sequence of characters here or a very small value
+            // was chosen for maximumTokens.
+            inputIdCnt = maximumTokens - 1; // leave one out for the final [SEP] token
+            lastTokenizedWordStartIndex = pivot.SegmentStartIndex + offset;
+            break;
         }

         moreRemainingInput = true;
@@ -370,9 +372,9 @@ bool NeedsRemoval(ReadOnlySpan<char> formD)
         return new string(span.Slice(0, i)).Normalize(targetNf);
     }

-    private int TokenizeSubword(ReadOnlySpan<char> word, Span<long> tokenIdSink)
+    private int TokenizeSubword(ReadOnlySpan<char> word, Span<long> tokenIdSink, ref int offset)
     {
-        int OnUnknown(ReadOnlySpan<char> word, Span<long> tokenIdSink)
+        int OnUnknown(ReadOnlySpan<char> word, Span<long> tokenIdSink, ref int offset)
         {
             if (RemoveControlAndReplacement(word, out var withoutControl))
             {
@@ -381,7 +383,7 @@ int OnUnknown(ReadOnlySpan<char> word, Span<long> tokenIdSink)
                 return 0;
             }

-            return TokenizeSubword(withoutControl, tokenIdSink);
+            return TokenizeSubword(withoutControl, tokenIdSink, ref offset);
         }

         // Normalize and IsNormalized for ReadOnlySpan<char> is not yet implemented:
@@ -392,16 +394,17 @@ int OnUnknown(ReadOnlySpan<char> word, Span<long> tokenIdSink)
         var wordStr = word.ToString();
         if (!wordStr.IsNormalized(_normalization))
         {
-            return TokenizeSubword(wordStr.Normalize(_normalization), tokenIdSink);
+            return TokenizeSubword(wordStr.Normalize(_normalization), tokenIdSink, ref offset);
         }

         var withoutDiacrit = RemoveDiacritics(wordStr, _normalization);
         if (!MemoryExtensions.Equals(withoutDiacrit, word, StringComparison.Ordinal))
         {
-            return TokenizeSubword(withoutDiacrit.AsSpan(), tokenIdSink);
+            return TokenizeSubword(withoutDiacrit.AsSpan(), tokenIdSink, ref offset);
         }

         tokenIdSink[0] = _unk.Id;
+        offset += word.Length;
         return 1;
     }

@@ -424,13 +427,14 @@ int OnUnknown(ReadOnlySpan<char> word, Span<long> tokenIdSink)

     if (id == -1)
     {
-        return OnUnknown(word, tokenIdSink);
+        return OnUnknown(word, tokenIdSink, ref offset);
     }

     tokenIdSink[0] = id;
     cnt++;

     var remaining = word.Slice(prefix.Length);
+    offset += prefix.Length;
     while (remaining.Length > 0 && cnt < tokenIdSink.Length)
     {
         var suffix = remaining;
@@ -450,12 +454,13 @@ int OnUnknown(ReadOnlySpan<char> word, Span<long> tokenIdSink)

         if (id == -1)
         {
-            return OnUnknown(word, tokenIdSink);
+            return OnUnknown(word, tokenIdSink, ref offset);
         }

         tokenIdSink[cnt] = id;
         cnt++;
         remaining = remaining.Slice(suffix.Length);
+        offset += suffix.Length;
     }

     return cnt;
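A note on the mechanics: the new ref int offset parameter threaded through TokenizeSubword and OnUnknown accumulates the number of source characters covered by the word pieces emitted so far. When the token buffer fills up mid-word, the caller can set lastTokenizedWordStartIndex = pivot.SegmentStartIndex + offset, i.e. point at the first character that did not make it into the output. A rough sketch of that bookkeeping under a hypothetical vocabulary lookup (the delegate and names are illustrative, not the library's internals):

using System;

static class OffsetSketch
{
    // Greedy longest-match WordPiece over a single word. 'lookup' stands in for
    // the real vocabulary and returns -1 for unknown pieces; continuation
    // pieces are prefixed with "##" as in BERT vocabularies.
    internal static int TokenizeSubword(
        ReadOnlySpan<char> word,
        Span<long> tokenIdSink,
        Func<string, long> lookup,
        ref int offset)
    {
        var cnt = 0;
        var remaining = word;
        while (remaining.Length > 0 && cnt < tokenIdSink.Length)
        {
            // Shrink the candidate from the right until the vocabulary knows it.
            var suffix = remaining;
            long id = -1;
            while (suffix.Length > 0)
            {
                id = lookup((cnt == 0 ? string.Empty : "##") + suffix.ToString());
                if (id != -1)
                {
                    break;
                }

                suffix = suffix.Slice(0, suffix.Length - 1);
            }

            if (id == -1)
            {
                return cnt; // the real implementation falls back to OnUnknown here
            }

            tokenIdSink[cnt] = id;
            cnt++;
            remaining = remaining.Slice(suffix.Length);
            offset += suffix.Length; // the bookkeeping this commit introduces
        }

        // If the sink filled up while characters remain, 'offset' tells the
        // caller exactly how far into the word the emitted tokens reach.
        return cnt;
    }
}

Because offset only advances by the length of fully matched pieces (or by the whole word on the [UNK] fallback), pivot.SegmentStartIndex + offset always lands on a valid cut point within the input.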
