Skip to content

Commit

Permalink
Unit test for entropy round trips
Browse files Browse the repository at this point in the history
  • Loading branch information
ynse01 committed Nov 27, 2024
1 parent 473adab commit 7edc7dc
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) Six Labors.
// Licensed under the Six Labors Split License.

using System;

namespace SixLabors.ImageSharp.Formats.Heif.Av1.Entropy;

internal static class Av1DefaultDistributions
Expand Down Expand Up @@ -1610,6 +1608,7 @@ internal static class Av1DefaultDistributions
],
];

// SVT: av1_default_txb_skip_cdfs
private static Av1Distribution[][][] TransformBlockSkip =>
[
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Entropy;

internal static class Av1SymbolContextHelper
{
public static readonly int[][] ExtendedTransformIndices = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 3, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 5, 6, 4, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0],
[3, 4, 5, 8, 6, 7, 9, 10, 11, 0, 1, 2, 0, 0, 0, 0],
[7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6],
];

public static readonly int[] EndOfBlockOffsetBits = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
public static readonly int[] EndOfBlockGroupStart = [0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513];
private static readonly int[] TransformCountInSet = [1, 2, 5, 7, 12, 16];
Expand Down
54 changes: 52 additions & 2 deletions src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Entropy;

internal ref struct Av1SymbolDecoder
{
private static readonly int[][] ExtendedTransformIndicesInverse = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 10, 11, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 10, 11, 0, 1, 2, 4, 5, 3, 6, 7, 8, 0, 0, 0, 0],
[9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 4, 5, 3, 6, 7, 8],
];

private static readonly int[] IntraModeContext = [0, 1, 2, 3, 4, 4, 4, 4, 3, 0, 1, 2, 0];
private static readonly int[] AlphaVContexts = [-1, 0, 3, -1, 1, 4, -1, 2, 5];

Expand All @@ -33,7 +42,8 @@ internal ref struct Av1SymbolDecoder
private readonly Av1Distribution[][][] endOfBlockExtra;
private readonly Av1Distribution chromeForLumaSign = Av1DefaultDistributions.ChromeForLumaSign;
private readonly Av1Distribution[] chromeForLumaAlpha = Av1DefaultDistributions.ChromeForLumaAlpha;
private Configuration configuration;
private readonly Av1Distribution[][][] intraExtendedTransform = Av1DefaultDistributions.IntraExtendedTransform;
private readonly Configuration configuration;
private Av1SymbolReader reader;

public Av1SymbolDecoder(Configuration configuration, Span<byte> tileData, int qIndex)
Expand Down Expand Up @@ -180,6 +190,46 @@ public Av1TransformSize ReadTransformSize(Av1BlockSize blockSize, int context)
return transformSize;
}

public Av1TransformType ReadTransformType(
Av1TransformSize transformSize,
bool useReducedTransformSet,
bool useFilterIntra,
int baseQIndex,
Av1FilterIntraMode filterIntraMode,
Av1PredictionMode intraDirection)
{
Av1TransformType transformType = Av1TransformType.DctDct;

/*
// No need to read transform type if block is skipped.
if (mbmi.Skip ||
svt_aom_seg_feature_active(&parse_ctxt->frame_header->segmentation_params, mbmi->segment_id, SEG_LVL_SKIP))
return;
*/

// Ignoring INTER blocks here, as these should not end up here.
// int inter_block = is_inter_block_dec(mbmi);
Av1TransformSetType tx_set_type = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet);
if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSize, useReducedTransformSet) > 1 && baseQIndex > 0)
{
int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSize, useReducedTransformSet);

// eset == 0 should correspond to a set with only DCT_DCT and
// there is no need to read the tx_type
Guard.IsFalse(extendedSet == 0, nameof(extendedSet), string.Empty);

Av1TransformSize squareTransformSize = transformSize.GetSquareSize();
Av1PredictionMode intraMode = useFilterIntra
? filterIntraMode.ToIntraDirection()
: intraDirection;
ref Av1SymbolReader r = ref this.reader;
int symbol = r.ReadSymbol(this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraMode]);
transformType = (Av1TransformType)ExtendedTransformIndicesInverse[(int)tx_set_type][symbol];
}

return transformType;
}

public bool ReadTransformBlockSkip(Av1TransformSize transformSizeContext, int skipContext)
{
ref Av1SymbolReader r = ref this.reader;
Expand Down Expand Up @@ -233,7 +283,7 @@ public int ReadCoefficients(
{
int width = transformSize.GetWidth();
int height = transformSize.GetHeight();
Av1TransformSize transformSizeContext = (Av1TransformSize)(((int)transformSize.GetSquareSize() + ((int)transformSize.GetSquareUpSize() + 1)) >> 1);
Av1TransformSize transformSizeContext = (Av1TransformSize)(((int)transformSize.GetSquareSize() + (int)transformSize.GetSquareUpSize() + 1) >> 1);
Av1PlaneType planeType = (Av1PlaneType)Math.Min(plane, 1);
int culLevel = 0;

Expand Down
96 changes: 46 additions & 50 deletions src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ internal class Av1SymbolEncoder : IDisposable

private readonly Av1Distribution tileIntraBlockCopy = Av1DefaultDistributions.IntraBlockCopy;
private readonly Av1Distribution[] tilePartitionTypes = Av1DefaultDistributions.PartitionTypes;
private readonly Av1Distribution[] skip = Av1DefaultDistributions.Skip;
private readonly Av1Distribution[][] transformBlockSkip;
private readonly Av1Distribution[][][] endOfBlockFlag;
private readonly Av1Distribution[][][] coefficientsBaseRange;
private readonly Av1Distribution[][][] coefficientsBase;
Expand All @@ -36,6 +36,7 @@ internal class Av1SymbolEncoder : IDisposable

public Av1SymbolEncoder(Configuration configuration, int initialSize, int qIndex)
{
this.transformBlockSkip = Av1DefaultDistributions.GetTransformBlockSkip(qIndex);
this.endOfBlockFlag = Av1DefaultDistributions.GetEndOfBlockFlag(qIndex);
this.coefficientsBaseRange = Av1DefaultDistributions.GetCoefficientsBaseRange(qIndex);
this.coefficientsBase = Av1DefaultDistributions.GetCoefficientsBase(qIndex);
Expand All @@ -49,7 +50,7 @@ public Av1SymbolEncoder(Configuration configuration, int initialSize, int qIndex
public void WriteUseIntraBlockCopy(bool value)
{
ref Av1SymbolWriter w = ref this.writer;
w.WriteSymbol(value ? 1 : 0, this.tileIntraBlockCopy);
w.WriteSymbol(value, this.tileIntraBlockCopy);
}

public void WritePartitionType(Av1PartitionType partitionType, int context)
Expand Down Expand Up @@ -82,10 +83,9 @@ public int WriteCoefficients(
Av1TransformType transformType,
int txbIndex, // TODO: Doesn't seem to be used, remove.
Av1PredictionMode intraDirection,
Span<int> coeffBuffer,
Span<int> coefficientBuffer,
Av1ComponentType componentType,
short transformBlockSkipContext,
short dcSignContext,
Av1TransformBlockContext transformBlockContext,
ushort eob,
bool useReducedTransformSet,
int baseQIndex,
Expand All @@ -97,54 +97,53 @@ public int WriteCoefficients(
Av1TransformClass transformClass = transformType.ToClass();
Av1ScanOrder scanOrder = Av1ScanOrderConstants.GetScanOrder(transformSize, transformType);
ReadOnlySpan<short> scan = scanOrder.Scan;
int bwl = transformSize.GetBlockWidthLog2();
int blockWidthLog2 = transformSize.GetBlockWidthLog2();
Av1TransformSize transformSizeContext = (Av1TransformSize)(((int)transformSize.GetSquareSize() + (int)transformSize.GetSquareUpSize() + 1) >> 1);

ref Av1SymbolWriter w = ref this.writer;

Av1LevelBuffer levels = new(this.configuration, new Size(width, height));
Span<sbyte> coeff_contexts = new sbyte[Av1Constants.MaxTransformSize * Av1Constants.MaxTransformSize];
Span<sbyte> coefficientContexts = new sbyte[Av1Constants.MaxTransformSize * Av1Constants.MaxTransformSize];

Guard.MustBeLessThan((int)transformSizeContext, (int)Av1TransformSize.AllSizes, nameof(transformSizeContext));

bool hasEndOfBlock = eob != 0;
this.WriteSkip(!hasEndOfBlock, transformBlockSkipContext);
this.WriteTransformBlockSkip(eob == 0, transformSizeContext, transformBlockContext.SkipContext);

if (eob == 0)
{
return 0;
}

levels.Initialize(coeffBuffer);
levels.Initialize(coefficientBuffer);
if (componentType == Av1ComponentType.Luminance)
{
this.WriteTransformType(transformType, transformSize, useReducedTransformSet, baseQIndex, filterIntraMode, intraDirection);
}

short endOfBlockPosition = Av1SymbolContextHelper.GetEndOfBlockPosition(eob, out int eob_extra);
short endOfBlockPosition = Av1SymbolContextHelper.GetEndOfBlockPosition(eob, out int eobExtra);
this.WriteEndOfBlockFlag(componentType, transformClass, transformSize, endOfBlockPosition);

int eob_offset_bits = Av1SymbolContextHelper.EndOfBlockOffsetBits[endOfBlockPosition];
if (eob_offset_bits > 0)
int eobOffsetBitCount = Av1SymbolContextHelper.EndOfBlockOffsetBits[endOfBlockPosition];
if (eobOffsetBitCount > 0)
{
int eob_shift = eob_offset_bits - 1;
int bit = (eob_extra & (1 << eob_shift)) != 0 ? 1 : 0;
int eobShift = eobOffsetBitCount - 1;
int bit = (eobExtra & (1 << eobShift)) != 0 ? 1 : 0;
w.WriteSymbol(bit, this.endOfBlockExtra[(int)transformSizeContext][(int)componentType][endOfBlockPosition]);
for (int i = 1; i < eob_offset_bits; i++)
for (int i = 1; i < eobOffsetBitCount; i++)
{
eob_shift = eob_offset_bits - 1 - i;
bit = (eob_extra & (1 << eob_shift)) != 0 ? 1 : 0;
eobShift = eobOffsetBitCount - 1 - i;
bit = (eobExtra & (1 << eobShift)) != 0 ? 1 : 0;
w.WriteLiteral((uint)bit, 1);
}
}

Av1SymbolContextHelper.GetNzMapContexts(levels, scan, eob, transformSize, transformClass, coeff_contexts);
Av1SymbolContextHelper.GetNzMapContexts(levels, scan, eob, transformSize, transformClass, coefficientContexts);
int limitedTransformSizeContext = Math.Min((int)transformSizeContext, (int)Av1TransformSize.Size32x32);
for (c = eob - 1; c >= 0; --c)
{
short pos = scan[c];
int v = coeffBuffer[pos];
short coeff_ctx = coeff_contexts[pos];
int v = coefficientBuffer[pos];
short coeff_ctx = coefficientContexts[pos];
int level = Math.Abs(v);

if (c == eob - 1)
Expand All @@ -160,7 +159,7 @@ public int WriteCoefficients(
{
// level is above 1.
int baseRange = level - 1 - Av1Constants.BaseLevelsCount;
int baseRangeContext = Av1SymbolContextHelper.GetBaseRangeContext(levels, pos, bwl, transformClass);
int baseRangeContext = Av1SymbolContextHelper.GetBaseRangeContext(levels, pos, blockWidthLog2, transformClass);
for (int idx = 0; idx < Av1Constants.CoefficientBaseRange; idx += Av1Constants.BaseRangeSizeMinus1)
{
int k = Math.Min(baseRange - idx, Av1Constants.BaseRangeSizeMinus1);
Expand All @@ -179,7 +178,7 @@ public int WriteCoefficients(
for (c = 0; c < eob; ++c)
{
short pos = scan[c];
int v = coeffBuffer[pos];
int v = coefficientBuffer[pos];
int level = Math.Abs(v);
cul_level += level;

Expand All @@ -188,7 +187,7 @@ public int WriteCoefficients(
{
if (c == 0)
{
w.WriteSymbol((int)sign, this.dcSign[(int)componentType][dcSignContext]);
w.WriteSymbol((int)sign, this.dcSign[(int)componentType][transformBlockContext.DcSignContext]);
}
else
{
Expand All @@ -205,15 +204,14 @@ public int WriteCoefficients(
cul_level = Math.Min(Av1Constants.CoefficientContextMask, cul_level);

// DC value
Av1SymbolContextHelper.SetDcSign(ref cul_level, coeffBuffer[0]);
Av1SymbolContextHelper.SetDcSign(ref cul_level, coefficientBuffer[0]);
return cul_level;
}

private void WriteSkip(bool hasEndOfBlock, int context)
internal void WriteTransformBlockSkip(bool skip, Av1TransformSize transformSizeContext, int skipContext)
{
// Has EOB, means we won't skip, negating the logic.
ref Av1SymbolWriter w = ref this.writer;
w.WriteSymbol(hasEndOfBlock ? 0 : 1, this.skip[context]);
w.WriteSymbol(skip, this.transformBlockSkip[(int)transformSizeContext][skipContext]);
}

public IMemoryOwner<byte> Exit()
Expand Down Expand Up @@ -250,7 +248,7 @@ private void WriteGolomb(int level)

for (int j = length - 1; j >= 0; --j)
{
w.WriteLiteral((uint)(x >> j & 0x01), 1);
w.WriteLiteral((uint)((x >> j) & 0x01), 1);
}
}

Expand All @@ -265,7 +263,7 @@ private void WriteEndOfBlockFlag(Av1ComponentType componentType, Av1TransformCla
/// <summary>
/// SVT: av1_write_tx_type
/// </summary>
private void WriteTransformType(
internal void WriteTransformType(
Av1TransformType transformType,
Av1TransformSize transformSize,
bool useReducedTransformSet,
Expand All @@ -277,34 +275,32 @@ private void WriteTransformType(
ref Av1SymbolWriter w = ref this.writer;
if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSize, useReducedTransformSet) > 1 && baseQIndex > 0)
{
Av1TransformSize square_tx_size = transformSize.GetSquareSize();
Guard.MustBeLessThanOrEqualTo((int)square_tx_size, Av1Constants.ExtendedTransformCount, nameof(square_tx_size));
Av1TransformSize squareTransformSize = transformSize.GetSquareSize();
Guard.MustBeLessThanOrEqualTo((int)squareTransformSize, Av1Constants.ExtendedTransformCount, nameof(squareTransformSize));

Av1TransformSetType tx_set_type = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet);
int eset = Av1SymbolContextHelper.GetExtendedTransformSet(transformSize, useReducedTransformSet);
Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet);
int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSize, useReducedTransformSet);

// eset == 0 should correspond to a set with only DCT_DCT and there
// is no need to send the tx_type
Guard.MustBeGreaterThan(eset, 0, nameof(eset));
Guard.MustBeGreaterThan(extendedSet, 0, nameof(extendedSet));

// assert(av1_ext_tx_used[tx_set_type][transformType]);
Av1PredictionMode intraMode;
if (filterIntraMode != Av1FilterIntraMode.AllFilterIntraModes)
{
Av1PredictionMode intra_dir;
if (filterIntraMode != Av1FilterIntraMode.AllFilterIntraModes)
{
intra_dir = filterIntraMode.ToIntraDirection();
}
else
{
intra_dir = intraDirection;
}

Guard.MustBeLessThan((int)intra_dir, 13, nameof(intra_dir));
Guard.MustBeLessThan((int)square_tx_size, 4, nameof(square_tx_size));
w.WriteSymbol(
ExtendedTransformIndices[(int)tx_set_type][(int)transformType],
this.intraExtendedTransform[eset][(int)square_tx_size][(int)intra_dir]);
intraMode = filterIntraMode.ToIntraDirection();
}
else
{
intraMode = intraDirection;
}

Guard.MustBeLessThan((int)intraMode, 13, nameof(intraMode));
Guard.MustBeLessThan((int)squareTransformSize, 4, nameof(squareTransformSize));
w.WriteSymbol(
ExtendedTransformIndices[(int)transformSetType][(int)transformType],
this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraMode]);
}
}
}
3 changes: 3 additions & 0 deletions src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ public Av1SymbolWriter(Configuration configuration, int initialSize)

public void Dispose() => this.memory.Dispose();

public void WriteSymbol(bool symbol, Av1Distribution distribution)
=> this.WriteSymbol(symbol ? 1 : 0, distribution);

public void WriteSymbol(int symbol, Av1Distribution distribution)
{
DebugGuard.MustBeGreaterThanOrEqualTo(symbol, 0, nameof(symbol));
Expand Down
6 changes: 3 additions & 3 deletions tests/ImageSharp.Tests/Formats/Heif/Av1/Av1BitStreamTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ public void WriteAsBoolean(bool[] booleans)
}

[Theory]
//[InlineData(6, 4)]
//[InlineData(42, 8)]
//[InlineData(52, 8)]
[InlineData(6, 4)]
[InlineData(42, 8)]
[InlineData(52, 8)]
[InlineData(4050, 16)]
public void WriteAsLiteral(uint value, int bitCount)
{
Expand Down
Loading

0 comments on commit 7edc7dc

Please sign in to comment.