Skip to content

Commit

Permalink
routing: refactor update payment state tests
Browse files Browse the repository at this point in the history
This commit refactors the resumePayment to extract some logics back to
paymentState so that the code is more testable. It also adds unit tests
for paymentState, and breaks the original MPPayment tests into independent tests
so that it's easier to maintain and debug. All the new tests are built
using mock so that the control flow is eaiser to setup and change.
  • Loading branch information
yyforyongyu committed Jun 23, 2021
1 parent e79e46e commit cd35981
Show file tree
Hide file tree
Showing 4 changed files with 1,262 additions and 358 deletions.
40 changes: 34 additions & 6 deletions routing/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (

type mockPaymentAttemptDispatcher struct {
mock.Mock

resultChan chan *htlcswitch.PaymentResult
}

var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
Expand All @@ -548,8 +550,11 @@ func (m *mockPaymentAttemptDispatcher) GetPaymentResult(attemptID uint64,
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
<-chan *htlcswitch.PaymentResult, error) {

args := m.Called(attemptID, paymentHash, deobfuscator)
return args.Get(0).(<-chan *htlcswitch.PaymentResult), args.Error(1)
m.Called(attemptID, paymentHash, deobfuscator)

// Instead of returning the mocked returned values, we need to return
// the chan resultChan so it can be converted into a read-only chan.
return m.resultChan, nil
}

func (m *mockPaymentAttemptDispatcher) CleanStore(
Expand All @@ -568,7 +573,7 @@ var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)
func (m *mockPaymentSessionSource) NewPaymentSession(
payment *LightningPayment) (PaymentSession, error) {

args := m.Called(m)
args := m.Called(payment)
return args.Get(0).(PaymentSession), args.Error(1)
}

Expand All @@ -586,6 +591,8 @@ func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession {

type mockMissionControl struct {
mock.Mock

failReason *channeldb.FailureReason
}

var _ MissionController = (*mockMissionControl)(nil)
Expand All @@ -596,8 +603,7 @@ func (m *mockMissionControl) ReportPaymentFail(
*channeldb.FailureReason, error) {

args := m.Called(paymentID, rt, failureSourceIdx, failure)
return args.Get(0).(*channeldb.FailureReason), args.Error(1)

return m.failReason, args.Error(1)
}

func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64,
Expand Down Expand Up @@ -642,6 +648,7 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,

type mockControlTower struct {
mock.Mock
sync.Mutex
}

var _ ControlTower = (*mockControlTower)(nil)
Expand All @@ -656,6 +663,9 @@ func (m *mockControlTower) InitPayment(phash lntypes.Hash,
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
a *channeldb.HTLCAttemptInfo) error {

m.Lock()
defer m.Unlock()

args := m.Called(phash, a)
return args.Error(0)
}
Expand All @@ -664,36 +674,54 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
*channeldb.HTLCAttempt, error) {

m.Lock()
defer m.Unlock()

args := m.Called(phash, pid, settleInfo)
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
}

func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {

m.Lock()
defer m.Unlock()

args := m.Called(phash, pid, failInfo)
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
}

func (m *mockControlTower) Fail(phash lntypes.Hash,
reason channeldb.FailureReason) error {

m.Lock()
defer m.Unlock()

args := m.Called(phash, reason)
return args.Error(0)
}

func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
*channeldb.MPPayment, error) {

m.Lock()
defer m.Unlock()
args := m.Called(phash)

// Type assertion on nil will fail, so we check and return here.
if args.Get(0) == nil {
return nil, args.Error(1)
}

return args.Get(0).(*channeldb.MPPayment), args.Error(1)
// Make a copy of the payment here to avoid data race.
p := args.Get(0).(*channeldb.MPPayment)
payment := &channeldb.MPPayment{
FailureReason: p.FailureReason,
}
payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs))
copy(payment.HTLCs, p.HTLCs)

return payment, args.Error(1)
}

func (m *mockControlTower) FetchInFlightPayments() (
Expand Down
161 changes: 92 additions & 69 deletions routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,53 @@ type paymentState struct {
numShardsInFlight int
remainingAmt lnwire.MilliSatoshi
remainingFees lnwire.MilliSatoshi
terminate bool

// terminate indicates the payment is in its final stage and no more
// shards should be launched. This value is true if we have an HTLC
// settled or the payment has an error.
terminate bool
}

// terminated returns a bool to indicate there are no further actions needed
// and we should return what we have, either the payment preimage or the
// payment error.
func (ps paymentState) terminated() bool {
// If the payment is in final stage and we have no in flight shards to
// wait result for, we consider the whole action terminated.
return ps.terminate && ps.numShardsInFlight == 0
}

// needWaitForShards returns a bool to specify whether we need to wait for the
// outcome of the shanrdHandler.
func (ps paymentState) needWaitForShards() bool {
// If we have in flight shards and the payment is in final stage, we
// need to wait for the outcomes from the shards. Or if we have no more
// money to be sent, we need to wait for the already launched shards.
if ps.numShardsInFlight == 0 {
return false
}
return ps.terminate || ps.remainingAmt == 0
}

// paymentState uses the passed payment to find the latest information we need
// to act on every iteration of the payment loop.
func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) (
// updatePaymentState will fetch db for the payment to find the latest
// information we need to act on every iteration of the payment loop and update
// the paymentState.
func (p *paymentLifecycle) updatePaymentState() (*channeldb.MPPayment,
*paymentState, error) {

// Fetch the latest payment from db.
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
if err != nil {
return nil, nil, err
}

// Fetch the total amount and fees that has already been sent in
// settled and still in-flight shards.
sentAmt, fees := payment.SentAmt()

// Sanity check we haven't sent a value larger than the payment amount.
if sentAmt > p.totalAmount {
return nil, fmt.Errorf("amount sent %v exceeds "+
return nil, nil, fmt.Errorf("amount sent %v exceeds "+
"total amount %v", sentAmt, p.totalAmount)
}

Expand All @@ -74,13 +106,15 @@ func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) (
// have returned with a result.
terminate := settle != nil || failure != nil

activeShards := payment.InFlightHTLCs()
return &paymentState{
numShardsInFlight: len(activeShards),
// Update the payment state.
state := &paymentState{
numShardsInFlight: len(payment.InFlightHTLCs()),
remainingAmt: p.totalAmount - sentAmt,
remainingFees: feeBudget,
terminate: terminate,
}, nil
}

return payment, state, nil
}

// resumePayment resumes the paymentLifecycle from the current state.
Expand All @@ -102,9 +136,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
// If we had any existing attempts outstanding, we'll start by spinning
// up goroutines that'll collect their results and deliver them to the
// lifecycle loop below.
payment, err := p.router.cfg.Control.FetchPayment(
p.identifier,
)
payment, _, err := p.updatePaymentState()
if err != nil {
return [32]byte{}, nil, err
}
Expand All @@ -128,34 +160,30 @@ lifecycle:
return [32]byte{}, nil, err
}

// We start every iteration by fetching the lastest state of
// the payment from the ControlTower. This ensures that we will
// act on the latest available information, whether we are
// resuming an existing payment or just sent a new attempt.
payment, err := p.router.cfg.Control.FetchPayment(
p.identifier,
)
if err != nil {
return [32]byte{}, nil, err
}

// Using this latest state of the payment, calculate
// information about our active shards and terminal conditions.
state, err := p.paymentState(payment)
// We update the payment state on every iteration. Since the
// payment state is affected by multiple goroutines (ie,
// collectResultAsync), it is NOT guaranteed that we always
// have the latest state here. This is fine as long as the
// state is consistent as a whole.
payment, currentState, err := p.updatePaymentState()
if err != nil {
return [32]byte{}, nil, err
}

log.Debugf("Payment %v in state terminate=%v, "+
"active_shards=%v, rem_value=%v, fee_limit=%v",
p.identifier, state.terminate, state.numShardsInFlight,
state.remainingAmt, state.remainingFees)
p.identifier, currentState.terminate,
currentState.numShardsInFlight,
currentState.remainingAmt, currentState.remainingFees,
)

// TODO(yy): sanity check all the states to make sure
// everything is expected.
switch {

// We have a terminal condition and no active shards, we are
// ready to exit.
case state.terminate && state.numShardsInFlight == 0:
case currentState.terminated():
// Find the first successful shard and return
// the preimage and route.
for _, a := range payment.HTLCs {
Expand All @@ -170,7 +198,7 @@ lifecycle:
// If we either reached a terminal error condition (but had
// active shards still) or there is no remaining value to send,
// we'll wait for a shard outcome.
case state.terminate || state.remainingAmt == 0:
case currentState.needWaitForShards():
// We still have outstanding shards, so wait for a new
// outcome to be available before re-evaluating our
// state.
Expand Down Expand Up @@ -212,8 +240,9 @@ lifecycle:

// Create a new payment attempt from the given payment session.
rt, err := p.paySession.RequestRoute(
state.remainingAmt, state.remainingFees,
uint32(state.numShardsInFlight), uint32(p.currentHeight),
currentState.remainingAmt, currentState.remainingFees,
uint32(currentState.numShardsInFlight),
uint32(p.currentHeight),
)
if err != nil {
log.Warnf("Failed to find route for payment %v: %v",
Expand All @@ -227,7 +256,7 @@ lifecycle:
// There is no route to try, and we have no active
// shards. This means that there is no way for us to
// send the payment, so mark it failed with no route.
if state.numShardsInFlight == 0 {
if currentState.numShardsInFlight == 0 {
failureCode := routeErr.FailureReason()
log.Debugf("Marking payment %v permanently "+
"failed with no route: %v",
Expand All @@ -253,22 +282,11 @@ lifecycle:

// If this route will consume the last remeining amount to send
// to the receiver, this will be our last shard (for now).
lastShard := rt.ReceiverAmt() == state.remainingAmt
lastShard := rt.ReceiverAmt() == currentState.remainingAmt

// We found a route to try, launch a new shard.
attempt, outcome, err := shardHandler.launchShard(rt, lastShard)
switch {
// We may get a terminal error if we've processed a shard with
// a terminal state (settled or permanent failure), while we
// were pathfinding. We know we're in a terminal state here,
// so we can continue and wait for our last shards to return.
case err == channeldb.ErrPaymentTerminal:
log.Infof("Payment %v in terminal state, abandoning "+
"shard", p.identifier)

continue lifecycle

case err != nil:
if err != nil {
return [32]byte{}, nil, err
}

Expand Down Expand Up @@ -297,6 +315,7 @@ lifecycle:
// Now that the shard was successfully sent, launch a go
// routine that will handle its result when its back.
shardHandler.collectResultAsync(attempt)

}
}

Expand Down Expand Up @@ -437,12 +456,30 @@ type shardResult struct {
}

// collectResultAsync launches a goroutine that will wait for the result of the
// given HTLC attempt to be available then handle its result. Note that it will
// fail the payment with the control tower if a terminal error is encountered.
// given HTLC attempt to be available then handle its result. It will fail the
// payment with the control tower if a terminal error is encountered.
func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) {

// errToSend is the error to be sent to sh.shardErrors.
var errToSend error

// handleResultErr is a function closure must be called using defer. It
// finishes collecting result by updating the payment state and send
// the error (or nil) to sh.shardErrors.
handleResultErr := func() {
// Send the error or quit.
select {
case p.shardErrors <- errToSend:
case <-p.router.quit:
case <-p.quit:
}

p.wg.Done()
}

p.wg.Add(1)
go func() {
defer p.wg.Done()
defer handleResultErr()

// Block until the result is available.
result, err := p.collectResult(attempt)
Expand All @@ -456,32 +493,18 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) {
attempt.AttemptID, p.identifier, err)
}

select {
case p.shardErrors <- err:
case <-p.router.quit:
case <-p.quit:
}
// Overwrite errToSend and return.
errToSend = err
return
}

// If a non-critical error was encountered handle it and mark
// the payment failed if the failure was terminal.
if result.err != nil {
err := p.handleSendError(attempt, result.err)
if err != nil {
select {
case p.shardErrors <- err:
case <-p.router.quit:
case <-p.quit:
}
return
}
}

select {
case p.shardErrors <- nil:
case <-p.router.quit:
case <-p.quit:
// Overwrite errToSend and return. Notice that the
// errToSend could be nil here.
errToSend = p.handleSendError(attempt, result.err)
return
}
}()
}
Expand Down
Loading

0 comments on commit cd35981

Please sign in to comment.