Skip to content

Commit

Permalink
Update conditional handling to use blocks (#162)
Browse files Browse the repository at this point in the history
* update conditional handling to use blocks

* update ll files
  • Loading branch information
cqc-melf committed Jul 19, 2024
1 parent 643b790 commit c86df01
Show file tree
Hide file tree
Showing 20 changed files with 524 additions and 883 deletions.
120 changes: 56 additions & 64 deletions pytket/qir/conversion/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def __init__(
self.cregs = _retrieve_registers(self.circuit.bits, BitRegister)
self.target_gateset = self.module.gateset.base_gateset

self.block_count = 0

self.wasm_sar_dict: dict[str, str] = {}
self.wasm_sar_dict["!llvm.module.flags"] = (
'attributes #1 = { "wasm" }\n\n!llvm.module.flags'
Expand Down Expand Up @@ -313,6 +315,9 @@ def __init__(

self.additional_quantum_gates: dict[OpType, pyqir.Function] = {}

entry = self.module.module.entry_block
self.module.module.builder.insert_at_end(entry)

for creg in self.circuit.c_registers:
self._reg2ssa_var(creg, qir_int_type)

Expand Down Expand Up @@ -615,6 +620,16 @@ def conv_RangePredicateOp(
def conv_conditional(self, command: Command, op: Conditional) -> None:
condition_name = command.args[0].reg_name

entry_point = self.module.module.entry_point

condb = pyqir.BasicBlock(
self.module.module.context, f"condb{self.block_count}", entry_point
)
contb = pyqir.BasicBlock(
self.module.module.context, f"contb{self.block_count}", entry_point
)
self.block_count = self.block_count + 1

if op.op.type == OpType.CircBox:
conditional_circuit = self._decompose_conditional_circ_box(
op.op, command.args[op.width :]
Expand All @@ -623,23 +638,8 @@ def conv_conditional(self, command: Command, op: Conditional) -> None:
condition_name = command.args[0].reg_name

if op.width == 1: # only one conditional bit
condition_bit_index = command.args[0].index[0]

def condition_block_true() -> None:
"""
Populate recursively the module with the contents of the
conditional sub-circuit when the condition is True.
"""
if op.value == 1:
self.subcircuit_to_module(conditional_circuit)

def condition_block_false() -> None:
"""
Populate recursively the module with the contents of the
conditional sub-circuit when the condition is False.
"""
if op.value == 0:
self.subcircuit_to_module(conditional_circuit)
condition_bit_index = command.args[0].index[0]

ssabool = self.module.builder.call(
self.get_creg_bit,
Expand All @@ -649,11 +649,18 @@ def condition_block_false() -> None:
],
)

self.module.module.builder.if_(
ssabool,
true=lambda: condition_block_true(),
false=lambda: condition_block_false(),
)
if op.value == 1:
self.module.module.builder.condbr(ssabool, condb, contb)
self.module.module.builder.insert_at_end(condb)
self.subcircuit_to_module(conditional_circuit)

if op.value == 0:
self.module.module.builder.condbr(ssabool, contb, condb)
self.module.module.builder.insert_at_end(condb)
self.subcircuit_to_module(conditional_circuit)

self.module.module.builder.br(contb)
self.module.module.builder.insert_at_end(contb)

else:
for i in range(op.width):
Expand All @@ -673,45 +680,27 @@ def condition_block_false() -> None:
"conditional can only work with one entire register"
)

def condition_block() -> None:
"""
Populate recursively the module with the contents of the
conditional sub-circuit when the condition is True.
"""
self.subcircuit_to_module(conditional_circuit)

ssabool = self.module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.value),
self._get_i64_ssa_reg(condition_name),
)

self.module.module.builder.if_(
ssabool,
true=lambda: condition_block(),
)
self.module.module.builder.condbr(ssabool, condb, contb)

self.module.module.builder.insert_at_end(condb)

self.subcircuit_to_module(conditional_circuit)

self.module.module.builder.br(contb)
self.module.module.builder.insert_at_end(contb)

else:
condition_name = command.args[0].reg_name

if op.width == 1: # only one conditional bit
condition_bit_index = command.args[0].index[0]

def condition_block_true() -> None:
"""
Populate recursively the module with the contents of the
conditional sub-circuit when the condition is True.
"""
if op.value == 1:
self.command_to_module(op.op, command.args[op.width :])

def condition_block_false() -> None:
"""
Populate recursively the module with the contents of the
conditional sub-circuit when the condition is False.
"""
if op.value == 0:
self.command_to_module(op.op, command.args[op.width :])

ssabool = self.module.builder.call(
self.get_creg_bit,
[
Expand All @@ -720,11 +709,18 @@ def condition_block_false() -> None:
],
)

self.module.module.builder.if_(
ssabool,
true=lambda: condition_block_true(),
false=lambda: condition_block_false(),
)
if op.value == 1:
self.module.module.builder.condbr(ssabool, condb, contb)
self.module.module.builder.insert_at_end(condb)
self.command_to_module(op.op, command.args[op.width :])

if op.value == 0:
self.module.module.builder.condbr(ssabool, contb, condb)
self.module.module.builder.insert_at_end(condb)
self.command_to_module(op.op, command.args[op.width :])

self.module.module.builder.br(contb)
self.module.module.builder.insert_at_end(contb)

else:
for i in range(op.width):
Expand All @@ -744,23 +740,19 @@ def condition_block_false() -> None:
"conditional can only work with one entire register"
)

def condition_block() -> None:
"""
Populate recursively the module with the contents of the
conditional sub-circuit when the condition is True.
"""
self.command_to_module(op.op, command.args[op.width :])

ssabool = self.module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.value),
self._get_i64_ssa_reg(condition_name),
)

self.module.module.builder.if_(
ssabool,
true=lambda: condition_block(),
)
self.module.module.builder.condbr(ssabool, condb, contb)
self.module.module.builder.insert_at_end(condb)

self.command_to_module(op.op, command.args[op.width :])

self.module.module.builder.br(contb)
self.module.module.builder.insert_at_end(contb)

def conv_WASMOp(self, op: WASMOp, args: Union[Bit, Qubit]) -> None:
paramreg, resultreg = self._get_c_regs_from_com(op, args)
Expand Down
132 changes: 48 additions & 84 deletions tests/qir/test_pytket_qir_14.ll
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,12 @@ entry:
%43 = and i1 %40, %42
call void @set_creg_bit(i1* %3, i64 8, i1 %43)
%44 = call i1 @get_creg_bit(i1* %0, i64 0)
br i1 %44, label %then, label %else
br i1 %44, label %condb0, label %contb0

then: ; preds = %entry
br label %continue
condb0: ; preds = %entry
br label %contb0

else: ; preds = %entry
br label %continue

continue: ; preds = %else, %then
contb0: ; preds = %condb0, %entry
%45 = call i1 @get_creg_bit(i1* %0, i64 0)
%46 = call i1 @get_creg_bit(i1* %1, i64 0)
%47 = xor i1 %45, %46
Expand All @@ -138,126 +135,93 @@ continue: ; preds = %else, %then
%62 = icmp eq i64 1, %61
call void @set_creg_bit(i1* %3, i64 3, i1 %62)
%63 = call i1 @get_creg_bit(i1* %3, i64 0)
br i1 %63, label %then1, label %else2
br i1 %63, label %condb1, label %contb1

then1: ; preds = %continue
condb1: ; preds = %contb0
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue3

else2: ; preds = %continue
br label %continue3
br label %contb1

continue3: ; preds = %else2, %then1
contb1: ; preds = %condb1, %contb0
%64 = call i1 @get_creg_bit(i1* %3, i64 1)
br i1 %64, label %then4, label %else5
br i1 %64, label %condb2, label %contb2

then4: ; preds = %continue3
condb2: ; preds = %contb1
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue6

else5: ; preds = %continue3
br label %continue6
br label %contb2

continue6: ; preds = %else5, %then4
contb2: ; preds = %condb2, %contb1
%65 = call i1 @get_creg_bit(i1* %3, i64 2)
br i1 %65, label %then7, label %else8
br i1 %65, label %condb3, label %contb3

then7: ; preds = %continue6
condb3: ; preds = %contb2
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue9
br label %contb3

else8: ; preds = %continue6
br label %continue9

continue9: ; preds = %else8, %then7
contb3: ; preds = %condb3, %contb2
%66 = call i1 @get_creg_bit(i1* %3, i64 3)
br i1 %66, label %then10, label %else11
br i1 %66, label %condb4, label %contb4

then10: ; preds = %continue9
condb4: ; preds = %contb3
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue12

else11: ; preds = %continue9
br label %continue12
br label %contb4

continue12: ; preds = %else11, %then10
contb4: ; preds = %condb4, %contb3
%67 = call i1 @get_creg_bit(i1* %0, i64 0)
br i1 %67, label %then13, label %else14
br i1 %67, label %condb5, label %contb5

then13: ; preds = %continue12
condb5: ; preds = %contb4
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue15

else14: ; preds = %continue12
br label %continue15
br label %contb5

continue15: ; preds = %else14, %then13
contb5: ; preds = %condb5, %contb4
%68 = call i1 @get_creg_bit(i1* %3, i64 4)
br i1 %68, label %then16, label %else17
br i1 %68, label %contb6, label %condb6

then16: ; preds = %continue15
br label %continue18

else17: ; preds = %continue15
condb6: ; preds = %contb5
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue18
br label %contb6

continue18: ; preds = %else17, %then16
contb6: ; preds = %condb6, %contb5
%69 = call i1 @get_creg_bit(i1* %0, i64 0)
br i1 %69, label %then19, label %else20

then19: ; preds = %continue18
br label %continue21
br i1 %69, label %contb7, label %condb7

else20: ; preds = %continue18
condb7: ; preds = %contb6
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue21
br label %contb7

continue21: ; preds = %else20, %then19
contb7: ; preds = %condb7, %contb6
%70 = call i1 @get_creg_bit(i1* %3, i64 5)
br i1 %70, label %then22, label %else23
br i1 %70, label %condb8, label %contb8

then22: ; preds = %continue21
condb8: ; preds = %contb7
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue24

else23: ; preds = %continue21
br label %continue24
br label %contb8

continue24: ; preds = %else23, %then22
contb8: ; preds = %condb8, %contb7
%71 = call i1 @get_creg_bit(i1* %3, i64 6)
br i1 %71, label %then25, label %else26
br i1 %71, label %condb9, label %contb9

then25: ; preds = %continue24
condb9: ; preds = %contb8
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue27
br label %contb9

else26: ; preds = %continue24
br label %continue27

continue27: ; preds = %else26, %then25
contb9: ; preds = %condb9, %contb8
%72 = call i1 @get_creg_bit(i1* %3, i64 7)
br i1 %72, label %then28, label %else29
br i1 %72, label %condb10, label %contb10

then28: ; preds = %continue27
condb10: ; preds = %contb9
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue30

else29: ; preds = %continue27
br label %continue30
br label %contb10

continue30: ; preds = %else29, %then28
contb10: ; preds = %condb10, %contb9
%73 = call i1 @get_creg_bit(i1* %3, i64 8)
br i1 %73, label %then31, label %else32
br i1 %73, label %condb11, label %contb11

then31: ; preds = %continue30
condb11: ; preds = %contb10
call void @__quantum__qis__x__body(%Qubit* null)
br label %continue33

else32: ; preds = %continue30
br label %continue33
br label %contb11

continue33: ; preds = %else32, %then31
contb11: ; preds = %condb11, %contb10
call void @__quantum__rt__tuple_start_record_output()
%74 = call i64 @get_int_from_creg(i1* %0)
call void @__quantum__rt__int_record_output(i64 %74, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @0, i32 0, i32 0))
Expand Down
Loading

0 comments on commit c86df01

Please sign in to comment.