Skip to content

Commit

Permalink
mm: Audit mutexes
Browse files Browse the repository at this point in the history
Fixes of mutex use in market making code.
  • Loading branch information
martonp committed Jul 23, 2024
1 parent fa356ff commit 648ded5
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 17 deletions.
56 changes: 43 additions & 13 deletions client/mm/exchange_adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ func withdrawalBalanceEffects(tx *asset.WalletTransaction, cexDebit uint64, asse
func (w *pendingWithdrawal) balanceEffects() (dex, cex *BalanceEffects) {
w.txMtx.RLock()
defer w.txMtx.RUnlock()

return withdrawalBalanceEffects(w.tx, w.amtWithdrawn, w.assetID)
}

Expand Down Expand Up @@ -516,6 +517,10 @@ func (u *unifiedExchangeAdaptor) withPause(f func() error) error {
return u.botLoop.ConnectOnce(u.ctx)
}

// logBalanceAdjustments logs a trace log of balance adjustments and updated
// settled balances.
//
// balancesMtx must be read locked when calling this function.
func (u *unifiedExchangeAdaptor) logBalanceAdjustments(dexDiffs, cexDiffs map[uint32]int64, reason string) {
if u.log.Level() > dex.LevelTrace {
return
Expand Down Expand Up @@ -552,10 +557,10 @@ func (u *unifiedExchangeAdaptor) logBalanceAdjustments(dexDiffs, cexDiffs map[ui

writeLine("Updated settled balances:")
writeLine(" DEX:")

for assetID, bal := range u.baseDexBalances {
writeLine(" " + format(assetID, bal, false))
}

if len(u.baseCexBalances) > 0 {
writeLine(" CEX:")
for assetID, bal := range u.baseCexBalances {
Expand Down Expand Up @@ -690,14 +695,13 @@ type dexOrderInfo struct {
// updateDEXOrderEvent updates the event log with the current state of a
// pending DEX order and sends an event notification.
func (u *unifiedExchangeAdaptor) updateDEXOrderEvent(o *pendingDEXOrder, complete bool) {
o.txsMtx.RLock()
transactions := make([]*asset.WalletTransaction, 0, len(o.swaps)+len(o.redeems)+len(o.refunds))
addTxs := func(txs map[string]*asset.WalletTransaction) {
for _, tx := range txs {
transactions = append(transactions, tx)
}
}

o.txsMtx.RLock()
addTxs(o.swaps)
addTxs(o.redeems)
addTxs(o.refunds)
Expand Down Expand Up @@ -1283,7 +1287,7 @@ type BotBalances struct {
CEX *BotBalance `json:"cex"`
}

// dexBalance must be called with the balancesMtx locked.
// dexBalance must be called with the balancesMtx read locked.
func (u *unifiedExchangeAdaptor) dexBalance(assetID uint32) *BotBalance {
bal, found := u.baseDexBalances[assetID]
if !found {
Expand Down Expand Up @@ -1376,19 +1380,28 @@ func (u *unifiedExchangeAdaptor) refreshAllPendingEvents(ctx context.Context) {
}

for _, pendingDeposit := range pendingDeposits {
u.confirmDeposit(ctx, pendingDeposit.tx.ID)
pendingDeposit.mtx.RLock()
id := pendingDeposit.tx.ID
pendingDeposit.mtx.RUnlock()

u.confirmDeposit(ctx, id)
}

for _, pendingWithdrawal := range pendingWithdrawals {
u.confirmWithdrawal(ctx, pendingWithdrawal.withdrawalID)
}

for _, pendingOrder := range pendingCEXOrders {
trade, err := u.CEX.TradeStatus(ctx, pendingOrder.trade.ID, pendingOrder.trade.BaseID, pendingOrder.trade.QuoteID)
pendingOrder.tradeMtx.RLock()
id, baseID, quoteID := pendingOrder.trade.ID, pendingOrder.trade.BaseID, pendingOrder.trade.QuoteID
pendingOrder.tradeMtx.RUnlock()

trade, err := u.CEX.TradeStatus(ctx, id, baseID, quoteID)
if err != nil {
u.log.Errorf("error getting CEX trade status: %v", err)
continue
}

u.handleCEXTradeUpdate(trade)
}
}
Expand Down Expand Up @@ -1423,7 +1436,7 @@ func cexTradeBalanceEffects(trade *libxc.Trade) (effects *BalanceEffects) {
return
}

// cexBalance must be called with the balancesMtx locked.
// cexBalance must be called with the balancesMtx read locked.
func (u *unifiedExchangeAdaptor) cexBalance(assetID uint32) *BotBalance {
totalEffects := newBalanceEffects()
addEffects := func(effects *BalanceEffects) {
Expand All @@ -1442,7 +1455,11 @@ func (u *unifiedExchangeAdaptor) cexBalance(assetID uint32) *BotBalance {
}

for _, pendingOrder := range u.pendingCEXOrders {
addEffects(cexTradeBalanceEffects(pendingOrder.trade))
pendingOrder.tradeMtx.RLock()
trade := pendingOrder.trade
pendingOrder.tradeMtx.RUnlock()

addEffects(cexTradeBalanceEffects(trade))
}

for _, withdrawal := range u.pendingWithdrawals {
Expand Down Expand Up @@ -1568,7 +1585,10 @@ func (u *unifiedExchangeAdaptor) pendingDepositComplete(deposit *pendingDeposit)
}

u.sendStatsUpdate()

u.balancesMtx.RLock()
u.logBalanceAdjustments(dexDiffs, cexDiffs, msg)
u.balancesMtx.RUnlock()
}

func (u *unifiedExchangeAdaptor) confirmDeposit(ctx context.Context, txID string) bool {
Expand Down Expand Up @@ -1733,7 +1753,10 @@ func (u *unifiedExchangeAdaptor) pendingWithdrawalComplete(id string, tx *asset.

dexDiffs := map[uint32]int64{withdrawal.assetID: dexEffects.Settled[withdrawal.assetID]}
cexDiffs := map[uint32]int64{withdrawal.assetID: cexEffects.Settled[withdrawal.assetID]}

u.balancesMtx.RLock()
u.logBalanceAdjustments(dexDiffs, cexDiffs, fmt.Sprintf("Withdrawal %s complete.", id))
u.balancesMtx.RUnlock()
}

func (u *unifiedExchangeAdaptor) confirmWithdrawal(ctx context.Context, id string) bool {
Expand Down Expand Up @@ -1809,22 +1832,24 @@ func (u *unifiedExchangeAdaptor) withdraw(ctx context.Context, assetID uint32, a
u.balancesMtx.Unlock()
return err
}

u.log.Infof("Withdrew %s", u.fmtQty(assetID, amount))
if assetID == u.baseID {
u.pendingBaseRebalance.Store(true)
} else {
u.pendingQuoteRebalance.Store(true)
}
u.pendingWithdrawals[withdrawalID] = &pendingWithdrawal{
withdrawal := &pendingWithdrawal{
eventLogID: u.eventLogID.Add(1),
timestamp: time.Now().Unix(),
assetID: assetID,
amtWithdrawn: amount,
withdrawalID: withdrawalID,
}
u.pendingWithdrawals[withdrawalID] = withdrawal
u.balancesMtx.Unlock()

u.updateWithdrawalEvent(u.pendingWithdrawals[withdrawalID], nil)
u.updateWithdrawalEvent(withdrawal, nil)
u.sendStatsUpdate()

u.wg.Add(1)
Expand Down Expand Up @@ -1973,6 +1998,7 @@ func (u *unifiedExchangeAdaptor) handleCEXTradeUpdate(trade *libxc.Trade) {

u.balancesMtx.Lock()
defer u.balancesMtx.Unlock()

delete(u.pendingCEXOrders, trade.ID)

if trade.BaseFilled == 0 && trade.QuoteFilled == 0 {
Expand Down Expand Up @@ -2234,7 +2260,11 @@ func (u *unifiedExchangeAdaptor) cancelAllOrders(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

tradeStatus, err := u.CEX.TradeStatus(ctx, pendingOrder.trade.ID, pendingOrder.trade.BaseID, pendingOrder.trade.QuoteID)
pendingOrder.tradeMtx.RLock()
id, baseID, quoteID := pendingOrder.trade.ID, pendingOrder.trade.BaseID, pendingOrder.trade.QuoteID
pendingOrder.tradeMtx.RUnlock()

tradeStatus, err := u.CEX.TradeStatus(ctx, id, baseID, quoteID)
if err != nil {
u.log.Errorf("Error getting CEX trade status: %v", err)
continue
Expand All @@ -2244,9 +2274,9 @@ func (u *unifiedExchangeAdaptor) cancelAllOrders(ctx context.Context) {
}

done = false
err = u.CEX.CancelTrade(ctx, u.baseID, u.quoteID, pendingOrder.trade.ID)
err = u.CEX.CancelTrade(ctx, baseID, quoteID, id)
if err != nil {
u.log.Errorf("Error canceling CEX trade %s: %v", pendingOrder.trade.ID, err)
u.log.Errorf("Error canceling CEX trade %s: %v", id, err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion client/mm/libxc/binance.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (b *binanceOrderBook) sync(ctx context.Context) {
func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error /* no errors */) {
const updateIDUnsynced = math.MaxUint64

// We'll run two goroutines and sychronize two local vars.
// We'll run two goroutines and synchronize two local vars.
var syncMtx sync.Mutex
var syncCache []*bntypes.BookUpdate
syncChan := make(chan struct{})
Expand Down
3 changes: 3 additions & 0 deletions client/mm/libxc/orderbook.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ func (ob *orderbook) vwap(bids bool, qty uint64) (vwap, extrema uint64, filled b
}

func (ob *orderbook) midGap() uint64 {
ob.mtx.RLock()
defer ob.mtx.RUnlock()

bestBuyI := ob.bids.Front()
if bestBuyI == nil {
return 0
Expand Down
14 changes: 11 additions & 3 deletions client/mm/mm.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ func (m *MarketMaker) handleCEXUpdate(cexName string, ni interface{}) {
cex := m.cexes[cexName]
m.cexMtx.RUnlock()
if cex == nil {
m.log.Errorf("CEX update received from uknown cex %q?", cexName)
m.log.Errorf("CEX update received from unknown cex %q?", cexName)
return
}
cex.mtx.Lock()
Expand All @@ -481,10 +481,12 @@ func (m *MarketMaker) handleCEXUpdate(cexName string, ni interface{}) {
func (m *MarketMaker) cexList() []*centralizedExchange {
m.cexMtx.RLock()
defer m.cexMtx.RUnlock()

cexes := make([]*centralizedExchange, 0, len(m.cexes))
for _, cex := range m.cexes {
cexes = append(cexes, cex)
}

return cexes
}

Expand Down Expand Up @@ -517,13 +519,18 @@ func (m *MarketMaker) Connect(ctx context.Context) (*sync.WaitGroup, error) {
go func() {
defer wg.Done()
<-ctx.Done()
for _, cex := range m.cexList() {

m.cexMtx.Lock()
defer m.cexMtx.Unlock()

for _, cex := range m.cexes {
cex.mtx.RLock()
cm := cex.cm
cex.mtx.RUnlock()
if cm != nil {
cm.Disconnect()
}

delete(m.cexes, cex.Name)
}
}()
Expand Down Expand Up @@ -843,7 +850,8 @@ func (m *MarketMaker) UpdateCEXConfig(updatedCfg *CEXConfig) error {
m.defaultCfg.CexConfigs = append(m.defaultCfg.CexConfigs, updatedCfg)
}
m.defaultCfgMtx.Unlock()
if err := m.writeConfigFile(m.defaultCfg); err != nil {

if err := m.writeConfigFile(m.defaultConfig()); err != nil {
m.log.Errorf("Error saving new bot configuration: %w", err)
}

Expand Down

0 comments on commit 648ded5

Please sign in to comment.