// FPAdd.scala
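/*
 * Floating-point adder, split into four stages:
 * Stage 1: find the difference between the two exponents
 * Stage 2: right-shift the mantissa of the operand with the smaller exponent
 * Stage 3: add/subtract the mantissas and check for overflow
 * Stage 4: normalize the mantissa and the exponent
 */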
package fpDivision
import chisel3._
import chisel3.util.Cat
import chisel3.util.Reverse
import chisel3.util.PriorityEncoder
import FloatUtils.{floatToBigInt, doubleToBigInt, getExpMantWidths,
floatAdd, doubleAdd}
/**
  * Saturating shifter: shifts `shiftin` by `shiftby` and forces the result to
  * zero when the shift amount is strictly greater than the operand width `m`.
  *
  * NOTE(review): despite the class name, the datapath performs a logical
  * RIGHT shift (`>>`). The name and behavior are both kept unchanged here
  * because callers may depend on either; confirm the intended direction.
  *
  * @param m width in bits of the value being shifted (and of the result)
  * @param n width in bits of the shift-amount input
  */
class SatLeftShift(val m: Int, val n: Int) extends Module {
  val io = IO(new Bundle {
    val shiftin  = Input(UInt(m.W))
    val shiftby  = Input(UInt(n.W))
    val shiftout = Output(UInt(m.W))
  })

  // Shift amounts beyond the operand width saturate the output to zero.
  val saturated = io.shiftby > m.U
  io.shiftout := Mux(saturated, 0.U, io.shiftin >> io.shiftby)
}
/**
  * Stage 1 of the adder: compare the exponents of the two operands.
  *
  * The operands are unpacked with `FloatWrapper`. The exponent difference
  * decides which operand has the larger exponent, which exponent and sign are
  * forwarded, and by how much the other mantissa has to be shifted in stage 2.
  * Operands with equal exponents are treated as "a is larger"; stage 3 deals
  * with the case where that guess was wrong for a subtraction.
  *
  * All outputs are registered, so this stage adds one cycle of latency.
  *
  * @param n total width in bits of the floating-point operands; the
  *          exponent/mantissa split comes from `getExpMantWidths(n)`
  */
class FPAddStage1(val n: Int) extends Module {
  val (expWidth, mantWidth) = getExpMantWidths(n)

  val io = IO(new Bundle {
    val a          = Input(Bits(n.W))
    val b          = Input(Bits(n.W))
    val b_larger   = Output(Bool())
    val mant_shift = Output(UInt(expWidth.W))
    val exp        = Output(UInt(expWidth.W))
    // one bit wider than mantWidth -- presumably the implicit leading bit
    // supplied by FloatWrapper; TODO confirm against FloatWrapper
    val manta      = Output(UInt((mantWidth + 1).W))
    val mantb      = Output(UInt((mantWidth + 1).W))
    val sign       = Output(Bool())
    val sub        = Output(Bool())
  })

  val a_wrap = new FloatWrapper(io.a)
  val b_wrap = new FloatWrapper(io.b)

  // Zero-extend both exponents by one bit before subtracting, so that the
  // extra top bit of the difference tells us whether it went negative.
  val ext_exp_a = Cat(0.U(1.W), a_wrap.exponent)
  val ext_exp_b = Cat(0.U(1.W), b_wrap.exponent)
  val exp_diff  = ext_exp_a - ext_exp_b
  //printf("Stage1 exp_diff: %d\n", exp_diff)

  val reg_b_larger   = Reg(Bool())
  val reg_mant_shift = Reg(UInt(expWidth.W))
  val reg_exp        = Reg(UInt(expWidth.W))
  val reg_manta      = RegNext(a_wrap.mantissa)
  val reg_mantb      = RegNext(b_wrap.mantissa)
  val reg_sign       = Reg(Bool())
  // Operands of opposite sign turn the addition into a subtraction.
  val reg_sub        = RegNext(a_wrap.sign ^ b_wrap.sign)

  // The sign bit of the extended difference is set when b's exponent is
  // larger than a's.
  when (exp_diff(expWidth) === 1.U) {
    // Negating the low bits gives the absolute value of the difference.
    reg_mant_shift := -exp_diff(expWidth - 1, 0)
    reg_b_larger   := true.B
    reg_exp        := b_wrap.exponent
    reg_sign       := b_wrap.sign
  } .otherwise {
    reg_mant_shift := exp_diff(expWidth - 1, 0)
    reg_b_larger   := false.B
    reg_exp        := a_wrap.exponent
    reg_sign       := a_wrap.sign
  }

  io.mant_shift := reg_mant_shift
  io.b_larger   := reg_b_larger
  io.exp        := reg_exp
  io.manta      := reg_manta
  io.mantb      := reg_mantb
  io.sign       := reg_sign
  io.sub        := reg_sub
}
/**
  * Stage 2 of the adder: align the mantissas.
  *
  * The mantissa belonging to the operand with the smaller exponent (as
  * decided in stage 1 via `b_larger`) is shifted right by `mant_shift`; the
  * other mantissa passes through untouched. After this stage `manta_out`
  * always carries the mantissa of the larger-exponent operand and
  * `mantb_out` the aligned one.
  *
  * All outputs are registered, so this stage adds one cycle of latency.
  *
  * @param n total width in bits of the floating-point operands; the
  *          exponent/mantissa split comes from `getExpMantWidths(n)`
  */
class FPAddStage2(val n: Int) extends Module {
  val (expWidth, mantWidth) = getExpMantWidths(n)

  val io = IO(new Bundle {
    val manta_in   = Input(UInt((mantWidth + 1).W))
    val mantb_in   = Input(UInt((mantWidth + 1).W))
    val exp_in     = Input(UInt(expWidth.W))
    val mant_shift = Input(UInt(expWidth.W))
    val b_larger   = Input(Bool())
    val sign_in    = Input(Bool())
    val sub_in     = Input(Bool())
    val manta_out  = Output(UInt((mantWidth + 1).W))
    val mantb_out  = Output(UInt((mantWidth + 1).W))
    val exp_out    = Output(UInt(expWidth.W))
    val sign_out   = Output(Bool())
    val sub_out    = Output(Bool())
  })

  // Route the operands so that the one with the larger exponent is kept
  // as-is and the other one gets shifted.
  val larger_mant  = Mux(io.b_larger, io.mantb_in, io.manta_in)
  val smaller_mant = Mux(io.b_larger, io.manta_in, io.mantb_in)

  //printf("mant_shift: %d\n", io.mant_shift)
  // A shift by the full mantissa width or more leaves nothing behind, so
  // the result is forced to zero in that case.
  val shifted_mant = Mux(io.mant_shift >= (mantWidth + 1).U,
                         0.U,
                         smaller_mant >> io.mant_shift)

  val reg_manta = RegNext(larger_mant)
  val reg_mantb = RegNext(shifted_mant)
  val reg_sign  = RegNext(io.sign_in)
  val reg_sub   = RegNext(io.sub_in)
  val reg_exp   = RegNext(io.exp_in)

  io.manta_out := reg_manta
  io.mantb_out := reg_mantb
  io.sign_out  := reg_sign
  io.sub_out   := reg_sub
  io.exp_out   := reg_exp
  //printf("Stage2 large mantissa: %d small mantissa: %d, shifted mantissa: %d \n", larger_mant, smaller_mant, shifted_mant)
}
/**
  * Stage 3 of the adder: add or subtract the aligned mantissas.
  *
  * Both mantissas are zero-extended by one bit so that the top bit of the
  * result exposes a carry (addition) or a borrow (subtraction):
  *  - borrow on subtraction: the result is negated and the sign is flipped;
  *  - carry on addition: the result is shifted right by one and the exponent
  *    is incremented;
  *  - otherwise the low bits of the result are forwarded unchanged.
  *
  * NOTE(review): the exponent increment has no overflow handling visible in
  * this module.
  *
  * All outputs are registered, so this stage adds one cycle of latency.
  *
  * @param n total width in bits of the floating-point operands; the
  *          exponent/mantissa split comes from `getExpMantWidths(n)`
  */
class FPAddStage3(val n: Int) extends Module {
  val (expWidth, mantWidth) = getExpMantWidths(n)

  val io = IO(new Bundle {
    val manta    = Input(UInt((mantWidth + 1).W))
    val mantb    = Input(UInt((mantWidth + 1).W))
    val exp_in   = Input(UInt(expWidth.W))
    val sign_in  = Input(Bool())
    val sub      = Input(Bool())
    val mant_out = Output(UInt((mantWidth + 1).W))
    val sign_out = Output(Bool())
    val exp_out  = Output(UInt(expWidth.W))
  })

  // One extra leading zero bit so that a carry/borrow shows up in the MSB.
  val manta_ext = Cat(0.U(1.W), io.manta)
  val mantb_ext = Cat(0.U(1.W), io.mantb)
  val mant_sum  = Mux(io.sub, manta_ext - mantb_ext, manta_ext + mantb_ext)
  //printf("Stage3 mant a: %d, mant b: %d, mantRes: %d\n", manta_ext, mantb_ext, mant_sum)

  // The registers are one bit narrower than mant_sum: the carry/borrow bit
  // is dropped here.
  val reg_mant = Reg(UInt((mantWidth + 1).W))
  val reg_sign = Reg(Bool())
  val reg_exp  = Reg(UInt(expWidth.W))

  when (mant_sum(mantWidth + 1) === 1.U) {
    when (io.sub) {
      // The subtraction went negative. This may happen if the operands were
      // of opposite sign but had the same exponent: take the absolute value
      // and invert the sign that stage 1 picked.
      reg_mant := -mant_sum(mantWidth, 0)
      reg_sign := !io.sign_in
      reg_exp  := io.exp_in
    } .otherwise {
      // The sum overflowed: shift back by one and increment the exponent.
      reg_mant := mant_sum(mantWidth + 1, 1)
      reg_exp  := io.exp_in + 1.U
      reg_sign := io.sign_in
    }
  } .otherwise {
    reg_mant := mant_sum(mantWidth, 0)
    reg_sign := io.sign_in
    reg_exp  := io.exp_in
  }

  io.sign_out := reg_sign
  io.exp_out  := reg_exp
  io.mant_out := reg_mant
  //printf("stage3 sign: %d, exp: %d, mant: %d\n", reg_sign, reg_exp, reg_mant)
}
/**
  * Stage 4 of the adder: normalize the mantissa and adjust the exponent.
  *
  * The mantissa is shifted left until its most significant set bit reaches
  * the top position; that top bit is then dropped from `mant_out`, and the
  * exponent is decreased by the shift amount. A zero mantissa yields a zero
  * mantissa and a zero exponent.
  *
  * This stage contains no registers; its outputs are combinational.
  *
  * NOTE(review): `exp_in - norm_shift` is an unsigned subtraction, so it
  * wraps around when the shift amount exceeds the exponent. No underflow
  * handling is visible here.
  *
  * @param n total width in bits of the floating-point operands; the
  *          exponent/mantissa split comes from `getExpMantWidths(n)`
  */
class FPAddStage4(val n: Int) extends Module {
  val (expWidth, mantWidth) = getExpMantWidths(n)

  val io = IO(new Bundle {
    val exp_in   = Input(UInt(expWidth.W))
    val mant_in  = Input(UInt((mantWidth + 1).W))
    val exp_out  = Output(UInt(expWidth.W))
    val mant_out = Output(UInt(mantWidth.W))
  })

  // PriorityEncoder finds the least significant set bit, so the mantissa is
  // bit-reversed first: the result is the distance of the most significant
  // set bit from the top.
  val norm_shift = PriorityEncoder(Reverse(io.mant_in))

  when (io.mant_in === 0.U) {
    // If the mantissa sum is zero, the result mantissa and exponent are zero.
    io.mant_out := 0.U
    io.exp_out  := 0.U
  } .otherwise {
    // After the shift the leading one sits at bit mantWidth; only the bits
    // below it are kept.
    io.mant_out := (io.mant_in << norm_shift)(mantWidth - 1, 0)
    io.exp_out  := io.exp_in - norm_shift
  }
  //printf("Stage4 norm_shift: %d, mant out: %d, exp out: %d \n", norm_shift, io.mant_out, io.exp_out)
}
/**
  * Floating-point adder built by chaining the four `FPAddStage*` modules.
  *
  * Stages 1 to 3 register their outputs and stage 4 is combinational, so
  * `res` reflects the operands presented on `a` and `b` three clock cycles
  * earlier. The result is assembled as sign | exponent | mantissa, with the
  * sign taken from stage 3 and the exponent and mantissa from stage 4.
  *
  * @param n total width in bits of the operands and of the result
  */
class FPAdd(val n: Int) extends Module {
  val io = IO(new Bundle {
    val a   = Input(Bits(n.W))
    val b   = Input(Bits(n.W))
    val res = Output(Bits(n.W))
  })

  val (expWidth, mantWidth) = getExpMantWidths(n)

  // Stage 1: exponent comparison
  val stage1 = Module(new FPAddStage1(n))
  stage1.io.a := io.a
  stage1.io.b := io.b

  // Stage 2: mantissa alignment
  val stage2 = Module(new FPAddStage2(n))
  stage2.io.manta_in   := stage1.io.manta
  stage2.io.mantb_in   := stage1.io.mantb
  stage2.io.exp_in     := stage1.io.exp
  stage2.io.sign_in    := stage1.io.sign
  stage2.io.sub_in     := stage1.io.sub
  stage2.io.b_larger   := stage1.io.b_larger
  stage2.io.mant_shift := stage1.io.mant_shift

  // Stage 3: mantissa add/subtract
  val stage3 = Module(new FPAddStage3(n))
  stage3.io.manta   := stage2.io.manta_out
  stage3.io.mantb   := stage2.io.mantb_out
  stage3.io.exp_in  := stage2.io.exp_out
  stage3.io.sign_in := stage2.io.sign_out
  stage3.io.sub     := stage2.io.sub_out

  // Stage 4: normalization (combinational, so it stays aligned with the
  // registered sign coming out of stage 3)
  val stage4 = Module(new FPAddStage4(n))
  stage4.io.exp_in  := stage3.io.exp_out
  stage4.io.mant_in := stage3.io.mant_out

  io.res := Cat(stage3.io.sign_out, stage4.io.exp_out, stage4.io.mant_out)
  //printf("FPAdd result: sign: %d, exp: %d, mant: %d res: %d\n", stage3.io.sign_out, stage4.io.exp_out, stage4.io.mant_out, io.res.toUInt())
}
/** [[FPAdd]] instantiated for 32-bit wide operands. */
class FPAdd32 extends FPAdd(32)
/** [[FPAdd]] instantiated for 64-bit wide operands. */
class FPAdd64 extends FPAdd(64)