Skip to content

Commit

Permalink
Aggregate PriorityMux and Mux1H Seq size errors (#4609)
Browse files Browse the repository at this point in the history
This requires adding source locators to PriorityMux and Mux1H which
requires SourceInfoTransform macros in Scala 2 and thus splitting the
Scala 2 / Scala 3 public interfaces.

Because macro applications do not support named arguments in Scala 2,
this is an API change.
  • Loading branch information
jackkoenig authored Jan 10, 2025
1 parent 5d68bea commit 0d73982
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ class SourceInfoTransform(val c: Context) extends AutoSourceTransform {
def inNEnUseDualPortSramNameArg(in: c.Tree, n: c.Tree, en: c.Tree, useDualPortSram: c.Tree, name: c.Tree): c.Tree = {
q"$thisObj.$doFuncTerm($in, $n, $en, $useDualPortSram, $name)($implicitSourceInfo)"
}

def selInArg(sel: c.Tree, in: c.Tree): c.Tree = {
q"$thisObj.$doFuncTerm($sel, $in)($implicitSourceInfo)"
}
}

// Workaround for https://github.com/sbt/sbt/issues/3966
Expand Down
56 changes: 55 additions & 1 deletion src/main/scala-2/chisel3/util/Mux.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,63 @@ package chisel3.util

import chisel3._
import chisel3.experimental.SourceInfo
import chisel3.internal.sourceinfo.MuxLookupTransform
import chisel3.internal.sourceinfo.{MuxLookupTransform, SourceInfoTransform}
import scala.language.experimental.macros

/** Builds a Mux tree out of the input signal vector using a one hot encoded
* select signal. Returns the output of the Mux tree.
*
* @example {{{
* val hotValue = chisel3.util.Mux1H(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
*
* @note results unspecified unless exactly one select signal is high
*/
object Mux1H extends Mux1HImpl {
def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = macro SourceInfoTransform.selInArg
def do_apply[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply[T <: Data](in: Iterable[(Bool, T)]): T = macro SourceInfoTransform.inArg
def do_apply[T <: Data](in: Iterable[(Bool, T)])(implicit sourceInfo: SourceInfo): T = _applyImpl(in)

def apply[T <: Data](sel: UInt, in: Seq[T]): T = macro SourceInfoTransform.selInArg
def do_apply[T <: Data](sel: UInt, in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply(sel: UInt, in: UInt): Bool = macro SourceInfoTransform.selInArg
def do_apply(sel: UInt, in: UInt)(implicit sourceInfo: SourceInfo): Bool = _applyImpl(sel, in)

}

/** Builds a Mux tree under the assumption that multiple select signals
* can be enabled. Priority is given to the first select signal.
*
* @example {{{
* val hotValue = chisel3.util.PriorityMux(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
* Returns the output of the Mux tree.
*/
object PriorityMux extends PriorityMuxImpl {

def apply[T <: Data](in: Seq[(Bool, T)]): T = macro SourceInfoTransform.inArg
def do_apply[T <: Data](in: Seq[(Bool, T)])(implicit sourceInfo: SourceInfo): T = _applyImpl(in)

def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = macro SourceInfoTransform.selInArg
def do_apply[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply[T <: Data](sel: Bits, in: Seq[T]): T = macro SourceInfoTransform.selInArg
def do_apply[T <: Data](sel: Bits, in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)
}

/** Creates a cascade of n Muxs to search for a key value. The Selector may be a UInt or an EnumType.
*
* @example {{{
Expand Down
47 changes: 47 additions & 0 deletions src/main/scala-3/chisel3/util/Mux.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,53 @@ package chisel3.util
import chisel3._
import chisel3.experimental.SourceInfo

/** Builds a Mux tree out of the input signal vector using a one hot encoded
* select signal. Returns the output of the Mux tree.
*
* @example {{{
* val hotValue = chisel3.util.Mux1H(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
*
* @note results unspecified unless exactly one select signal is high
*/
object Mux1H extends Mux1HImpl {

def apply[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply[T <: Data](in: Iterable[(Bool, T)])(implicit sourceInfo: SourceInfo): T = _applyImpl(in)

def apply[T <: Data](sel: UInt, in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply(sel: UInt, in: UInt)(implicit sourceInfo: SourceInfo): Bool = _applyImpl(sel, in)
}

/** Builds a Mux tree under the assumption that multiple select signals
* can be enabled. Priority is given to the first select signal.
*
* @example {{{
* val hotValue = chisel3.util.PriorityMux(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
* Returns the output of the Mux tree.
*/
object PriorityMux extends PriorityMuxImpl {

def apply[T <: Data](in: Seq[(Bool, T)])(implicit sourceInfo: SourceInfo): T = _applyImpl(in)

def apply[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply[T <: Data](sel: Bits, in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)
}

/** Creates a cascade of n Muxs to search for a key value. The Selector may be a UInt or an EnumType.
*
* @example {{{
Expand Down
71 changes: 27 additions & 44 deletions src/main/scala/chisel3/util/MuxImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,55 +7,38 @@ package chisel3.util

import chisel3._
import chisel3.experimental.SourceInfo
import chisel3.internal.Builder

/** Builds a Mux tree out of the input signal vector using a one hot encoded
* select signal. Returns the output of the Mux tree.
*
* @example {{{
* val hotValue = chisel3.util.Mux1H(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
*
* @note results unspecified unless exactly one select signal is high
*/
object Mux1H {
def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = {
require(sel.size == in.size, s"Mux1H: input Seqs must have the same length, got sel ${sel.size} and in ${in.size}")
apply(sel.zip(in))
private[chisel3] trait Mux1HImpl {
protected def _applyImpl[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = {
if (sel.size != in.size) {
Builder.error(s"Mux1H: input Seqs must have the same length, got sel ${sel.size} and in ${in.size}")
}
_applyImpl(sel.zip(in))
}
def apply[T <: Data](in: Iterable[(Bool, T)]): T = SeqUtils.oneHotMux(in)
def apply[T <: Data](sel: UInt, in: Seq[T]): T =
apply((0 until in.size).map(sel(_)), in)
def apply(sel: UInt, in: UInt): Bool = (sel & in).orR

protected def _applyImpl[T <: Data](in: Iterable[(Bool, T)])(implicit sourceInfo: SourceInfo): T =
SeqUtils.oneHotMux(in)

protected def _applyImpl[T <: Data](sel: UInt, in: Seq[T])(implicit sourceInfo: SourceInfo): T =
_applyImpl((0 until in.size).map(sel(_)), in)

protected def _applyImpl(sel: UInt, in: UInt)(implicit sourceInfo: SourceInfo): Bool = (sel & in).orR
}

/** Builds a Mux tree under the assumption that multiple select signals
* can be enabled. Priority is given to the first select signal.
*
* @example {{{
* val hotValue = chisel3.util.PriorityMux(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
* Returns the output of the Mux tree.
*/
object PriorityMux {
def apply[T <: Data](in: Seq[(Bool, T)]): T = SeqUtils.priorityMux(in)
def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = {
require(
sel.size == in.size,
s"PriorityMux: input Seqs must have the same length, got sel ${sel.size} and in ${in.size}"
)
apply(sel.zip(in))
private[chisel3] trait PriorityMuxImpl {

protected def _applyImpl[T <: Data](in: Seq[(Bool, T)]): T = SeqUtils.priorityMux(in)

protected def _applyImpl[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = {
if (sel.size != in.size) {
Builder.error(s"PriorityMux: input Seqs must have the same length, got sel ${sel.size} and in ${in.size}")
}
_applyImpl(sel.zip(in))
}
def apply[T <: Data](sel: Bits, in: Seq[T]): T = apply((0 until in.size).map(sel(_)), in)

protected def _applyImpl[T <: Data](sel: Bits, in: Seq[T])(implicit sourceInfo: SourceInfo): T =
_applyImpl((0 until in.size).map(sel(_)), in)
}

private[chisel3] trait MuxLookupImpl {
Expand Down
18 changes: 15 additions & 3 deletions src/test/scala/chiselTests/OneHotMuxSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package chiselTests
import chisel3._
import chisel3.testers.BasicTester
import chisel3.util.{Mux1H, UIntToOH}
import _root_.circt.stage.ChiselStage.emitCHIRRTL
import org.scalatest._
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
Expand Down Expand Up @@ -32,12 +33,23 @@ class OneHotMuxSpec extends AnyFreeSpec with Matchers with ChiselRunners {
e.getMessage should include("Mux1H must have a non-empty argument")
}
"Mux1H should give a error when given different size Seqs" in {
val e = intercept[IllegalArgumentException] {
Mux1H(Seq(true.B, true.B), Seq(1.U, 2.U, 3.U))
val e = intercept[ChiselException] {
emitCHIRRTL(
new RawModule {
Mux1H(Seq(true.B, false.B), Seq(1.U, 2.U, 3.U))
},
args = Array("--throw-on-first-error")
)
}
e.getMessage should include("OneHotMuxSpec.scala") // Make sure source locator comes from this file
e.getMessage should include("Mux1H: input Seqs must have the same length, got sel 2 and in 3")
}

// The input bitvector is sign extended to the width of the sequence
"Mux1H should NOT error when given mismatched selector width and Seq size" in {
emitCHIRRTL(new RawModule {
Mux1H("b10".U(2.W), Seq(1.U, 2.U, 3.U))
})
}
}

class SimpleOneHotTester extends BasicTester {
Expand Down
17 changes: 15 additions & 2 deletions src/test/scala/chiselTests/util/PriorityMuxSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,22 @@ class PriorityMuxSpec extends ChiselFlatSpec {
}

it should "give a error when given different size Seqs" in {
val e = intercept[IllegalArgumentException] {
PriorityMux(Seq(true.B, true.B), Seq(1.U, 2.U, 3.U))
val e = intercept[ChiselException] {
emitCHIRRTL(
new RawModule {
PriorityMux(Seq(true.B, false.B), Seq(1.U, 2.U, 3.U))
},
args = Array("--throw-on-first-error")
)
}
e.getMessage should include("PriorityMuxSpec.scala") // Make sure source locator comes from this file
e.getMessage should include("PriorityMux: input Seqs must have the same length, got sel 2 and in 3")
}

// The input bitvector is sign extended to the width of the sequence
it should "NOT error when given mismatched selector width and Seq size" in {
emitCHIRRTL(new RawModule {
PriorityMux("b10".U(2.W), Seq(1.U, 2.U, 3.U))
})
}
}

0 comments on commit 0d73982

Please sign in to comment.