diff --git a/t1/src/mask/MaskCompress.scala b/t1/src/mask/MaskCompress.scala index aef7421f55..2ece87ca95 100644 --- a/t1/src/mask/MaskCompress.scala +++ b/t1/src/mask/MaskCompress.scala @@ -4,9 +4,26 @@ package org.chipsalliance.t1.rtl import chisel3._ +import chisel3.experimental.hierarchy.{instantiable, Instance, Instantiate} +import chisel3.experimental.{SerializableModule, SerializableModuleParameter} +import chisel3.properties.{AnyClassType, Path, Property} import chisel3.util._ +import org.chipsalliance.stdlib.GeneralOM -class CompressInput(parameter: T1Parameter) extends Bundle { +case class CompressParam( + datapathWidth: Int, + xLen: Int, + vLen: Int, + laneNumber: Int, + groupNumberBits: Int, + latency: Int) + extends SerializableModuleParameter + +object CompressParam { + implicit def rwP = upickle.default.macroRW[CompressParam] +} + +class CompressInput(parameter: CompressParam) extends Bundle { val maskType: Bool = Bool() val eew: UInt = UInt(2.W) val uop: UInt = UInt(3.W) @@ -14,26 +31,55 @@ class CompressInput(parameter: T1Parameter) extends Bundle { val source1: UInt = UInt(parameter.datapathWidth.W) val mask: UInt = UInt(parameter.datapathWidth.W) val source2: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W) - val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W) + val groupCounter: UInt = UInt(parameter.groupNumberBits.W) val ffoInput: UInt = UInt(parameter.laneNumber.W) val validInput: UInt = UInt(parameter.laneNumber.W) val lastCompress: Bool = Bool() } -class CompressOutput(parameter: T1Parameter) extends Bundle { +class CompressOutput(parameter: CompressParam) extends Bundle { val data: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W) val mask: UInt = UInt((parameter.laneNumber * parameter.datapathWidth / 8).W) - val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W) + val groupCounter: UInt = UInt(parameter.groupNumberBits.W) val ffoOutput: UInt = UInt(parameter.laneNumber.W) val compressValid: Bool = Bool() } -class MaskCompress(parameter: T1Parameter) extends Module { - val in: ValidIO[CompressInput] = IO(Flipped(Valid(new CompressInput(parameter)))) - val out: CompressOutput = IO(Output(new CompressOutput(parameter))) - val newInstruction: Bool = IO(Input(Bool())) - val ffoInstruction: Bool = IO(Input(Bool())) - val writeData: UInt = IO(Output(UInt(parameter.xLen.W))) +class MaskCompressInterFace(parameter: CompressParam) extends Bundle { + val clock = Input(Clock()) + val reset = Input(Reset()) + + val in: ValidIO[CompressInput] = Flipped(Valid(new CompressInput(parameter))) + val out: CompressOutput = Output(new CompressOutput(parameter)) + val newInstruction: Bool = Input(Bool()) + val ffoInstruction: Bool = Input(Bool()) + val writeData: UInt = Output(UInt(parameter.xLen.W)) + val om = Output(Property[AnyClassType]()) +} + +@instantiable +class MaskCompressOM(parameter: CompressParam) extends GeneralOM[CompressParam, MaskCompress](parameter) { + override def hasRetime: Boolean = true +} + +class MaskCompress(val parameter: CompressParam) + extends FixedIORawModule(new MaskCompressInterFace(parameter)) + with SerializableModule[CompressParam] + with ImplicitClock + with ImplicitReset { + + protected def implicitClock = io.clock + protected def implicitReset = io.reset + + val omInstance: Instance[MaskCompressOM] = Instantiate(new MaskCompressOM(parameter)) + io.om := omInstance.getPropertyReference + omInstance.retimeIn.foreach(_ := Property(Path(io.clock))) + + val in = io.in + val out = io.out + val newInstruction = io.newInstruction + val ffoInstruction = io.ffoInstruction + val writeData = io.writeData val maskSize: Int = parameter.laneNumber * parameter.datapathWidth / 8 @@ -122,7 +168,7 @@ class MaskCompress(parameter: T1Parameter) extends Module { val compressDataReg = RegInit(0.U((parameter.laneNumber * parameter.datapathWidth).W)) val compressTailValid: Bool = RegInit(false.B) - val compressWriteGroupCount: UInt = RegInit(0.U(parameter.laneParam.groupNumberBits.W)) + val compressWriteGroupCount: UInt = RegInit(0.U(parameter.groupNumberBits.W)) val compressDataVec = Seq(1, 2, 4).map { dataByte => val dataBit = dataByte * 8 val elementSizePerSet = parameter.laneNumber * parameter.datapathWidth / 8 / dataByte @@ -238,5 +284,5 @@ class MaskCompress(parameter: T1Parameter) extends Module { ffoIndex := source1SigExtend } outWire.ffoOutput := completedLeftOr | Fill(parameter.laneNumber, ffoValid) - out := RegNext(outWire, 0.U.asTypeOf(out)) + out := Pipe(true.B, outWire, parameter.latency).bits } diff --git a/t1/src/mask/MaskUnit.scala b/t1/src/mask/MaskUnit.scala index 76cd556d2a..eb67b40669 100644 --- a/t1/src/mask/MaskUnit.scala +++ b/t1/src/mask/MaskUnit.scala @@ -89,6 +89,12 @@ class MaskUnitOM(parameter: T1Parameter) extends GeneralOM[T1Parameter, MaskUnit @public val reduceUnitIn = IO(Input(Property[AnyClassType]())) reduceUnit := reduceUnitIn + + @public + val compress = IO(Output(Property[AnyClassType]())) + @public + val compressIn = IO(Input(Property[AnyClassType]())) + compress := compressIn } // TODO: no T1Parameter here. @@ -898,14 +904,24 @@ class MaskUnit(val parameter: T1Parameter) // Determine whether the data is ready val executeEnqValid: Bool = otherTypeRequestDeq && !readType + val compressParam: CompressParam = CompressParam( + parameter.datapathWidth, + parameter.xLen, + parameter.vLen, + parameter.laneNumber, + parameter.laneParam.groupNumberBits, + 1 + ) // start execute - val compressUnit: MaskCompress = Module(new MaskCompress(parameter)) - val reduceUnit = Instantiate( + val compressUnit = Instantiate(new MaskCompress(compressParam)) + val reduceUnit = Instantiate( new MaskReduce( MaskReduceParameter(parameter.datapathWidth, parameter.laneNumber, parameter.fpuEnable) ) ) omInstance.reduceUnitIn := reduceUnit.io.om.asAnyClassType + omInstance.compressIn := compressUnit.io.om.asAnyClassType + val extendUnit: MaskExtend = Module(new MaskExtend(parameter)) // todo @@ -935,28 +951,30 @@ class MaskUnit(val parameter: T1Parameter) val compressSource1: UInt = Mux1H(sew1H, vs1Split.map(_._1)) val source1Select: UInt = Mux(mv, readVS1Reg.data, compressSource1) val source1Change: Bool = Mux1H(sew1H, vs1Split.map(_._2)) - when(source1Change && compressUnit.in.fire) { + when(source1Change && compressUnit.io.in.fire) { readVS1Reg.dataValid := false.B readVS1Reg.requestSend := false.B readVS1Reg.readIndex := readVS1Reg.readIndex + 1.U } - viotaCounterAdd := compressUnit.in.fire - - compressUnit.in.valid := executeEnqValid && unitType(1) - compressUnit.in.bits.maskType := instReg.maskType - compressUnit.in.bits.eew := instReg.sew - compressUnit.in.bits.uop := instReg.decodeResult(Decoder.topUop) - compressUnit.in.bits.readFromScalar := instReg.readFromScala - compressUnit.in.bits.source1 := source1Select - compressUnit.in.bits.mask := executeElementMask - compressUnit.in.bits.source2 := source2 - compressUnit.in.bits.groupCounter := requestCounter - compressUnit.in.bits.lastCompress := lastGroup - compressUnit.in.bits.ffoInput := VecInit(exeReqReg.map(_.bits.ffo)).asUInt - compressUnit.in.bits.validInput := VecInit(exeReqReg.map(_.valid)).asUInt - compressUnit.newInstruction := instReq.valid - compressUnit.ffoInstruction := instReq.bits.decodeResult(Decoder.topUop)(2, 0) === BitPat("b11?") + viotaCounterAdd := compressUnit.io.in.fire + + compressUnit.io.clock := implicitClock + compressUnit.io.reset := implicitReset + compressUnit.io.in.valid := executeEnqValid && unitType(1) + compressUnit.io.in.bits.maskType := instReg.maskType + compressUnit.io.in.bits.eew := instReg.sew + compressUnit.io.in.bits.uop := instReg.decodeResult(Decoder.topUop) + compressUnit.io.in.bits.readFromScalar := instReg.readFromScala + compressUnit.io.in.bits.source1 := source1Select + compressUnit.io.in.bits.mask := executeElementMask + compressUnit.io.in.bits.source2 := source2 + compressUnit.io.in.bits.groupCounter := requestCounter + compressUnit.io.in.bits.lastCompress := lastGroup + compressUnit.io.in.bits.ffoInput := VecInit(exeReqReg.map(_.bits.ffo)).asUInt + compressUnit.io.in.bits.validInput := VecInit(exeReqReg.map(_.valid)).asUInt + compressUnit.io.newInstruction := instReq.valid + compressUnit.io.ffoInstruction := instReq.bits.decodeResult(Decoder.topUop)(2, 0) === BitPat("b11?") reduceUnit.io.clock := implicitClock reduceUnit.io.reset := implicitReset @@ -980,7 +998,7 @@ class MaskUnit(val parameter: T1Parameter) sink := VecInit(exeReqReg.map(_.bits.fpReduceValid.get)).asUInt } - when(reduceUnit.io.in.fire || compressUnit.in.fire) { + when(reduceUnit.io.in.fire || compressUnit.io.in.fire) { readVS1Reg.sendToExecution := true.B } @@ -1001,7 +1019,7 @@ class MaskUnit(val parameter: T1Parameter) val executeResult: UInt = Mux1H( unitType(3, 1), Seq( - compressUnit.out.data, + compressUnit.io.out.data, reduceUnit.io.out.bits.data, extendUnit.out ) @@ -1021,7 +1039,7 @@ class MaskUnit(val parameter: T1Parameter) val executeValid: Bool = Mux1H( unitType(3, 1), Seq( - compressUnit.out.compressValid, + compressUnit.io.out.compressValid, false.B, executeEnqValid ) @@ -1039,13 +1057,13 @@ class MaskUnit(val parameter: T1Parameter) val executeDeqGroupCounter: UInt = Mux1H( unitType(3, 1), Seq( - compressUnit.out.groupCounter, + compressUnit.io.out.groupCounter, requestCounter, extendGroupCount ) ) - val executeWriteByteMask: UInt = Mux(compress || ffo || mvVd, compressUnit.out.mask, executeByteMask) + val executeWriteByteMask: UInt = Mux(compress || ffo || mvVd, compressUnit.io.out.mask, executeByteMask) maskedWrite.needWAR := maskDestinationType maskedWrite.vd := instReg.vd maskedWrite.in.zipWithIndex.foreach { case (req, index) => @@ -1057,7 +1075,7 @@ class MaskUnit(val parameter: T1Parameter) req.bits.pipeData := exeReqReg(index).bits.source1 req.bits.bitMask := bitMask req.bits.groupCounter := executeDeqGroupCounter - req.bits.ffoByOther := compressUnit.out.ffoOutput(index) && ffo + req.bits.ffoByOther := compressUnit.io.out.ffoOutput(index) && ffo if (index == 0) { // reduce result when(unitType(2)) { @@ -1117,7 +1135,7 @@ class MaskUnit(val parameter: T1Parameter) val executeStageInvalid: Bool = Mux1H( unitType(3, 1), Seq( - !compressUnit.out.compressValid, + !compressUnit.io.out.compressValid, reduceUnit.io.in.ready, true.B ) @@ -1136,7 +1154,7 @@ class MaskUnit(val parameter: T1Parameter) lastReportValid, indexToOH(instReg.instructionIndex, parameter.chainingSize) ) - writeRDData := Mux(pop, reduceUnit.io.out.bits.data, compressUnit.writeData) + writeRDData := Mux(pop, reduceUnit.io.out.bits.data, compressUnit.io.writeData) // gather read state when(gatherRequestFire) {