Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FPU rounding module #728

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
47 changes: 47 additions & 0 deletions coreblocks/func_blocks/fu/fpu/fpu_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from amaranth.lib import enum


class RoundingModes(enum.Enum, shape=3):
ROUND_UP = 3
ROUND_DOWN = 2
ROUND_ZERO = 1
ROUND_NEAREST_EVEN = 0
ROUND_NEAREST_AWAY = 4


class FPUParams:
"""FPU parameters

Parameters
----------
sig_width: int
Width of significand
exp_width: int
Width of exponent
"""

def __init__(
self,
*,
sig_width: int = 24,
exp_width: int = 8,
):
self.sig_width = sig_width
self.exp_width = exp_width


class FPURoundingParams:
"""FPU rounding module signature

Parameters
-----------
fpu_params: FPUParams
FPU parameters
is_rounded:bool
This flags indicates that the input number was already rounded.
This creates simpler version of rounding module that only performs error checking and returns correct number.
"""

def __init__(self, fpu_params: FPUParams, *, is_rounded: bool = False):
self.fpu_params = fpu_params
self.is_rounded = is_rounded
257 changes: 257 additions & 0 deletions coreblocks/func_blocks/fu/fpu/fpu_rounding_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
from amaranth import *
from amaranth.lib.wiring import Component, In, Out, Signature
from transactron import TModule, Method, def_method
from coreblocks.func_blocks.fu.fpu.fpu_common import (
RoundingModes,
FPUParams,
FPURoundingParams,
)


class FPURoundingSignature(Signature):
"""FPU Rounding module signature

Parameters
----------
fpu_params: FPUParams
FPU parameters
"""

def __init__(self, *, fpu_params: FPUParams):
super().__init__(
{
"in_sign": In(1),
"in_sig": In(fpu_params.sig_width),
"in_exp": In(fpu_params.exp_width),
"rounding_mode": In(3),
"guard_bit": In(1),
"sticky_bit": In(1),
"in_errors": In(3),
"out_sign": In(1),
"out_sig": Out(fpu_params.sig_width),
"out_exp": Out(fpu_params.exp_width),
"out_error": Out(3),
}
)


class FPURoudningMethodLayout:
"""FPU Rounding module layouts for methods

Parameters
----------
fpu_params: FPUParams
FPU parameters
"""

def __init__(self, *, fpu_params: FPUParams):
self.rounding_in_layout = [
("sign", 1),
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("guard_bit", 1),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean round_bit?

("sticky_bit", 1),
("rounding_mode", 3),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enums can be used as signal shapes

Suggested change
("rounding_mode", 3),
("rounding_mode", RoundingModes),

("errors", 5),
("input_nan", 1),
("input_inf", 1),
]
self.rounding_out_layout = [
("sign", 1),
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("errors", 5),
]


class FPUrounding(Component):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring would be helpful


fpu_rounding: FPURoundingSignature
Comment on lines +67 to +69
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the Component used? Is the interface used aynwhere?


def __init__(self, *, fpu_rounding_params: FPURoundingParams):
super().__init__({"fpu_rounding": Out(FPURoundingSignature(fpu_params=fpu_rounding_params.fpu_params))})

self.fpu_rounding_params = fpu_rounding_params
self.method_layouts = FPURoudningMethodLayout(fpu_params=self.fpu_rounding_params.fpu_params)
self.rounding_request = Method(
i=self.method_layouts.rounding_in_layout,
o=self.method_layouts.rounding_out_layout,
)
self.rtval = {}
self.max_exp = C(
2 ** (self.fpu_rounding_params.fpu_params.exp_width) - 1,
unsigned(self.fpu_rounding_params.fpu_params.exp_width),
)
self.max_normal_exp = C(
2 ** (self.fpu_rounding_params.fpu_params.exp_width) - 2,
unsigned(self.fpu_rounding_params.fpu_params.exp_width),
)
self.quiet_nan = C(
2 ** (self.fpu_rounding_params.fpu_params.sig_width - 1),
unsigned(self.fpu_rounding_params.fpu_params.sig_width),
)
self.max_sig = C(
2 ** (self.fpu_rounding_params.fpu_params.sig_width) - 1,
unsigned(self.fpu_rounding_params.fpu_params.sig_width),
)
self.add_one = Signal()
self.inc_rtnte = Signal()
self.inc_rtnta = Signal()
self.inc_rtpi = Signal()
self.inc_rtmi = Signal()

self.rounded_sig = Signal(self.fpu_rounding_params.fpu_params.sig_width + 1)
self.normalised_sig = Signal(self.fpu_rounding_params.fpu_params.sig_width)
self.rounded_exp = Signal(self.fpu_rounding_params.fpu_params.exp_width + 1)

self.final_guard_bit = Signal()
self.final_sticky_bit = Signal()

self.overflow = Signal()
self.underflow = Signal()
self.inexact = Signal()
self.tininess = Signal()
self.is_inf = Signal()
self.is_nan = Signal()
self.input_not_special = Signal()
self.rounded_inexact = Signal()

self.final_exp = Signal(self.fpu_rounding_params.fpu_params.exp_width)
self.final_sig = Signal(self.fpu_rounding_params.fpu_params.sig_width)
self.final_sign = Signal()
Comment on lines +80 to +121
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those are module internal signals/constants, there is no need to make them public. definitions could be placed in elaborate code and without self.

self.final_errors = Signal(5)

def elaborate(self, platform):
m = TModule()

@def_method(m, self.rounding_request)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, all of comb domains could be replaced with av_comb, to skip adding extra multiplexer (that enables the assignment only on rounding_request run condition).

def _(arg):

m.d.comb += self.inc_rtnte.eq(
(arg.rounding_mode == RoundingModes.ROUND_NEAREST_EVEN)
& (arg.guard_bit & (arg.sticky_bit | arg.sig[0]))
)
m.d.comb += self.inc_rtnta.eq((arg.rounding_mode == RoundingModes.ROUND_NEAREST_AWAY) & (arg.guard_bit))
m.d.comb += self.inc_rtpi.eq(
(arg.rounding_mode == RoundingModes.ROUND_UP) & (~arg.sign & (arg.guard_bit | arg.sticky_bit))
)
m.d.comb += self.inc_rtmi.eq(
(arg.rounding_mode == RoundingModes.ROUND_DOWN) & (arg.sign & (arg.guard_bit | arg.sticky_bit))
)

m.d.comb += self.add_one.eq(self.inc_rtmi | self.inc_rtnta | self.inc_rtnte | self.inc_rtpi)

if self.fpu_rounding_params.is_rounded:

m.d.comb += self.normalised_sig.eq(arg.sig)
m.d.comb += self.final_guard_bit.eq(arg.guard_bit)
m.d.comb += self.final_sticky_bit.eq(arg.sticky_bit)
m.d.comb += self.rounded_exp.eq(arg.exp)

else:

m.d.comb += self.rounded_sig.eq(arg.sig + self.add_one)

with m.If(self.rounded_sig[-1]):

m.d.comb += self.normalised_sig.eq(self.rounded_sig >> 1)
m.d.comb += self.final_guard_bit.eq(self.rounded_sig[0])
m.d.comb += self.final_sticky_bit.eq(arg.guard_bit | arg.sticky_bit)
m.d.comb += self.rounded_exp.eq(arg.exp + 1)

with m.Else():
m.d.comb += self.normalised_sig.eq(self.rounded_sig)
m.d.comb += self.final_guard_bit.eq(arg.guard_bit)
m.d.comb += self.final_sticky_bit.eq(arg.sticky_bit)
m.d.comb += self.rounded_exp.eq(arg.exp)

m.d.comb += self.rounded_inexact.eq(self.final_guard_bit | self.final_sticky_bit)
m.d.comb += self.is_nan.eq(arg.errors[0] | arg.input_nan)
m.d.comb += self.is_inf.eq(arg.errors[1] | arg.input_inf)
Comment on lines +169 to +170
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both errors ans input_nan/inf are input fields. Why is the information duplicated? Is the errors input needed at all in the input interface?

I may completely misundersand intended use case of the module, but when it is used for example in is_rounded=True mode, shouldn't it detect inf/nan conditions from the number itself?

m.d.comb += self.input_not_special.eq(~(self.is_nan) & ~(self.is_inf))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In some simple cases, like in conditions, there is no need to explicitly create separate signals, but simply assign to python variable like input_not_special = ~is_nan & ~is_inf

m.d.comb += self.overflow.eq(self.input_not_special & (self.rounded_exp >= self.max_exp))
m.d.comb += self.tininess.eq(
(self.rounded_exp == 0) & (self.rounded_inexact | self.rounded_sig.any()) & (~self.normalised_sig[-1])
)
m.d.comb += self.inexact.eq(self.overflow | (self.input_not_special & self.rounded_inexact))
m.d.comb += self.underflow.eq(self.tininess & self.inexact)

with m.If(self.is_nan):

m.d.comb += self.final_exp.eq(self.max_exp)
m.d.comb += self.final_sig.eq(arg.sig)
m.d.comb += self.final_sign.eq(arg.sign)

with m.Elif(self.is_inf):

m.d.comb += self.final_exp.eq(self.max_exp)
m.d.comb += self.final_sig.eq(arg.sig)
m.d.comb += self.final_sign.eq(arg.sign)

Comment on lines +178 to +190
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If/Elif assignments are the same

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the exp part overwritten (like the number could be incorrect when flags are set), while sig value is preserved, while being critical for classification?

with m.Elif(self.overflow):

with m.If(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amaranth switch/case could be used: with m.Switch(arg.rounding_mode)

(arg.rounding_mode == RoundingModes.ROUND_NEAREST_AWAY)
| (arg.rounding_mode == RoundingModes.ROUND_NEAREST_EVEN)
):

m.d.comb += self.final_exp.eq(self.max_exp)
m.d.comb += self.final_sig.eq(0)
m.d.comb += self.final_sign.eq(arg.sign)

with m.If(arg.rounding_mode == RoundingModes.ROUND_ZERO):

m.d.comb += self.final_exp.eq(self.max_normal_exp)
m.d.comb += self.final_sig.eq(self.max_sig)
m.d.comb += self.final_sign.eq(arg.sign)

with m.If(arg.rounding_mode == RoundingModes.ROUND_DOWN):

with m.If(arg.sign):

m.d.comb += self.final_exp.eq(self.max_exp)
m.d.comb += self.final_sig.eq(0)
m.d.comb += self.final_sign.eq(arg.sign)

with m.Else():

m.d.comb += self.final_exp.eq(self.max_normal_exp)
m.d.comb += self.final_sig.eq(self.max_sig)
m.d.comb += self.final_sign.eq(arg.sign)

with m.If(arg.rounding_mode == RoundingModes.ROUND_UP):

with m.If(arg.sign):

m.d.comb += self.final_exp.eq(self.max_normal_exp)
m.d.comb += self.final_sig.eq(self.max_sig)
m.d.comb += self.final_sign.eq(arg.sign)

with m.Else():

m.d.comb += self.final_exp.eq(self.max_exp)
m.d.comb += self.final_sig.eq(0)
m.d.comb += self.final_sign.eq(arg.sign)

with m.Else():
with m.If((self.rounded_exp == 0) & (self.normalised_sig[-1] == 1)):
m.d.comb += self.final_exp.eq(1)
Comment on lines +237 to +238
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this condition?

with m.Else():
m.d.comb += self.final_exp.eq(self.rounded_exp)
m.d.comb += self.final_sig.eq(self.normalised_sig)
m.d.comb += self.final_sign.eq(arg.sign)

m.d.comb += self.final_errors[0].eq(arg.errors[0])
m.d.comb += self.final_errors[1].eq(arg.errors[1])
m.d.comb += self.final_errors[2].eq(self.overflow)
m.d.comb += self.final_errors[3].eq(self.underflow)
m.d.comb += self.final_errors[4].eq(self.inexact)

self.rtval["exp"] = self.final_exp
self.rtval["sig"] = self.final_sig
self.rtval["sign"] = self.final_sign
self.rtval["errors"] = self.final_errors

return self.rtval
Comment on lines +249 to +255
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use inline dictionary in return, like:

return {
    "exp": self.final_exp,
     ...


return m
Loading