Skip to content

Commit

Permalink
Aggregate PriorityMux and Mux1H Seq size errors
Browse files Browse the repository at this point in the history
This requires adding source locators to PriorityMux and Mux1H which
requires SourceInfoTransform macros in Scala 2 and thus splitting the
Scala 2 / Scala 3 public interfaces.

Because macro applications do not support named arguments in Scala 2,
this is an API change.
  • Loading branch information
jackkoenig committed Jan 9, 2025
1 parent 5d68bea commit 2263c56
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ class SourceInfoTransform(val c: Context) extends AutoSourceTransform {
def inNEnUseDualPortSramNameArg(in: c.Tree, n: c.Tree, en: c.Tree, useDualPortSram: c.Tree, name: c.Tree): c.Tree = {
q"$thisObj.$doFuncTerm($in, $n, $en, $useDualPortSram, $name)($implicitSourceInfo)"
}

def selInArg(sel: c.Tree, in: c.Tree): c.Tree = {
q"$thisObj.$doFuncTerm($sel, $in)($implicitSourceInfo)"
}
}

// Workaround for https://github.com/sbt/sbt/issues/3966
Expand Down
56 changes: 55 additions & 1 deletion src/main/scala-2/chisel3/util/Mux.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,63 @@ package chisel3.util

import chisel3._
import chisel3.experimental.SourceInfo
import chisel3.internal.sourceinfo.MuxLookupTransform
import chisel3.internal.sourceinfo.{MuxLookupTransform, SourceInfoTransform}
import scala.language.experimental.macros

/** Builds a Mux tree out of the input signal vector using a one hot encoded
* select signal. Returns the output of the Mux tree.
*
* @example {{{
* val hotValue = chisel3.util.Mux1H(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
*
* @note results unspecified unless exactly one select signal is high
*/
object Mux1H extends Mux1HImpl {
def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = macro SourceInfoTransform.selInArg
def do_apply[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply[T <: Data](in: Iterable[(Bool, T)]): T = macro SourceInfoTransform.inArg
def do_apply[T <: Data](in: Iterable[(Bool, T)])(implicit sourceInfo: SourceInfo): T = _applyImpl(in)

def apply[T <: Data](sel: UInt, in: Seq[T]): T = macro SourceInfoTransform.selInArg
def do_apply[T <: Data](sel: UInt, in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply(sel: UInt, in: UInt): Bool = macro SourceInfoTransform.selInArg
def do_apply(sel: UInt, in: UInt)(implicit sourceInfo: SourceInfo): Bool = _applyImpl(sel, in)

}

/** Builds a Mux tree under the assumption that multiple select signals
* can be enabled. Priority is given to the first select signal.
*
* @example {{{
* val hotValue = chisel3.util.PriorityMux(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
* Returns the output of the Mux tree.
*/
object PriorityMux extends PriorityMuxImpl {

def apply[T <: Data](in: Seq[(Bool, T)]): T = macro SourceInfoTransform.inArg
def do_apply[T <: Data](in: Seq[(Bool, T)])(implicit sourceInfo: SourceInfo): T = _applyImpl(in)

def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = macro SourceInfoTransform.selInArg
def do_apply[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply[T <: Data](sel: Bits, in: Seq[T]): T = macro SourceInfoTransform.selInArg
def do_apply[T <: Data](sel: Bits, in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)
}

/** Creates a cascade of n Muxs to search for a key value. The Selector may be a UInt or an EnumType.
*
* @example {{{
Expand Down
47 changes: 47 additions & 0 deletions src/main/scala-3/chisel3/util/Mux.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,53 @@ package chisel3.util
import chisel3._
import chisel3.experimental.SourceInfo

/** Builds a Mux tree out of the input signal vector using a one hot encoded
* select signal. Returns the output of the Mux tree.
*
* @example {{{
* val hotValue = chisel3.util.Mux1H(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
*
* @note results unspecified unless exactly one select signal is high
*/
object Mux1H extends Mux1HImpl {

def apply[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply[T <: Data](in: Iterable[(Bool, T)])(implicit sourceInfo: SourceInfo): T = _applyImpl(in)

def apply[T <: Data](sel: UInt, in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply(sel: UInt, in: UInt)(implicit sourceInfo: SourceInfo): Bool = _applyImpl(sel, in)
}

/** Builds a Mux tree under the assumption that multiple select signals
* can be enabled. Priority is given to the first select signal.
*
* @example {{{
* val hotValue = chisel3.util.PriorityMux(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
* Returns the output of the Mux tree.
*/
object PriorityMux extends PriorityMuxImpl {

def apply[T <: Data](in: Seq[(Bool, T)])(implicit sourceInfo: SourceInfo): T = _applyImpl(in)

def apply[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)

def apply[T <: Data](sel: Bits, in: Seq[T])(implicit sourceInfo: SourceInfo): T = _applyImpl(sel, in)
}

/** Creates a cascade of n Muxs to search for a key value. The Selector may be a UInt or an EnumType.
*
* @example {{{
Expand Down
71 changes: 27 additions & 44 deletions src/main/scala/chisel3/util/MuxImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,55 +7,38 @@ package chisel3.util

import chisel3._
import chisel3.experimental.SourceInfo
import chisel3.internal.Builder

/** Builds a Mux tree out of the input signal vector using a one hot encoded
* select signal. Returns the output of the Mux tree.
*
* @example {{{
* val hotValue = chisel3.util.Mux1H(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
*
* @note results unspecified unless exactly one select signal is high
*/
object Mux1H {
def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = {
require(sel.size == in.size, s"Mux1H: input Seqs must have the same length, got sel ${sel.size} and in ${in.size}")
apply(sel.zip(in))
private[chisel3] trait Mux1HImpl {
protected def _applyImpl[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = {
if (sel.size != in.size) {
Builder.error(s"Mux1H: input Seqs must have the same length, got sel ${sel.size} and in ${in.size}")
}
_applyImpl(sel.zip(in))
}
def apply[T <: Data](in: Iterable[(Bool, T)]): T = SeqUtils.oneHotMux(in)
def apply[T <: Data](sel: UInt, in: Seq[T]): T =
apply((0 until in.size).map(sel(_)), in)
def apply(sel: UInt, in: UInt): Bool = (sel & in).orR

protected def _applyImpl[T <: Data](in: Iterable[(Bool, T)])(implicit sourceInfo: SourceInfo): T =
SeqUtils.oneHotMux(in)

protected def _applyImpl[T <: Data](sel: UInt, in: Seq[T])(implicit sourceInfo: SourceInfo): T =
_applyImpl((0 until in.size).map(sel(_)), in)

protected def _applyImpl(sel: UInt, in: UInt)(implicit sourceInfo: SourceInfo): Bool = (sel & in).orR
}

/** Builds a Mux tree under the assumption that multiple select signals
* can be enabled. Priority is given to the first select signal.
*
* @example {{{
* val hotValue = chisel3.util.PriorityMux(Seq(
* io.selector(0) -> 2.U,
* io.selector(1) -> 4.U,
* io.selector(2) -> 8.U,
* io.selector(4) -> 11.U,
* ))
* }}}
* Returns the output of the Mux tree.
*/
object PriorityMux {
def apply[T <: Data](in: Seq[(Bool, T)]): T = SeqUtils.priorityMux(in)
def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = {
require(
sel.size == in.size,
s"PriorityMux: input Seqs must have the same length, got sel ${sel.size} and in ${in.size}"
)
apply(sel.zip(in))
private[chisel3] trait PriorityMuxImpl {

protected def _applyImpl[T <: Data](in: Seq[(Bool, T)]): T = SeqUtils.priorityMux(in)

protected def _applyImpl[T <: Data](sel: Seq[Bool], in: Seq[T])(implicit sourceInfo: SourceInfo): T = {
if (sel.size != in.size) {
Builder.error(s"PriorityMux: input Seqs must have the same length, got sel ${sel.size} and in ${in.size}")
}
_applyImpl(sel.zip(in))
}
def apply[T <: Data](sel: Bits, in: Seq[T]): T = apply((0 until in.size).map(sel(_)), in)

protected def _applyImpl[T <: Data](sel: Bits, in: Seq[T])(implicit sourceInfo: SourceInfo): T =
_applyImpl((0 until in.size).map(sel(_)), in)
}

private[chisel3] trait MuxLookupImpl {
Expand Down
18 changes: 15 additions & 3 deletions src/test/scala/chiselTests/OneHotMuxSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package chiselTests
import chisel3._
import chisel3.testers.BasicTester
import chisel3.util.{Mux1H, UIntToOH}
import _root_.circt.stage.ChiselStage.emitCHIRRTL
import org.scalatest._
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
Expand Down Expand Up @@ -32,12 +33,23 @@ class OneHotMuxSpec extends AnyFreeSpec with Matchers with ChiselRunners {
e.getMessage should include("Mux1H must have a non-empty argument")
}
"Mux1H should give a error when given different size Seqs" in {
val e = intercept[IllegalArgumentException] {
Mux1H(Seq(true.B, true.B), Seq(1.U, 2.U, 3.U))
val e = intercept[ChiselException] {
emitCHIRRTL(
new RawModule {
Mux1H(Seq(true.B, false.B), Seq(1.U, 2.U, 3.U))
},
args = Array("--throw-on-first-error")
)
}
e.getMessage should include("OneHotMuxSpec.scala") // Make sure source locator comes from this file
e.getMessage should include("Mux1H: input Seqs must have the same length, got sel 2 and in 3")
}

// The input bitvector is sign extended to the width of the sequence
"Mux1H should NOT error when given mismatched selector width and Seq size" in {
emitCHIRRTL(new RawModule {
Mux1H("b10".U(2.W), Seq(1.U, 2.U, 3.U))
})
}
}

class SimpleOneHotTester extends BasicTester {
Expand Down
17 changes: 15 additions & 2 deletions src/test/scala/chiselTests/util/PriorityMuxSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,22 @@ class PriorityMuxSpec extends ChiselFlatSpec {
}

it should "give a error when given different size Seqs" in {
val e = intercept[IllegalArgumentException] {
PriorityMux(Seq(true.B, true.B), Seq(1.U, 2.U, 3.U))
val e = intercept[ChiselException] {
emitCHIRRTL(
new RawModule {
PriorityMux(Seq(true.B, false.B), Seq(1.U, 2.U, 3.U))
},
args = Array("--throw-on-first-error")
)
}
e.getMessage should include("PriorityMuxSpec.scala") // Make sure source locator comes from this file
e.getMessage should include("PriorityMux: input Seqs must have the same length, got sel 2 and in 3")
}

// The input bitvector is sign extended to the width of the sequence
it should "NOT error when given mismatched selector width and Seq size" in {
emitCHIRRTL(new RawModule {
PriorityMux("b10".U(2.W), Seq(1.U, 2.U, 3.U))
})
}
}

0 comments on commit 2263c56

Please sign in to comment.