From 648ded5619289aa1bd966805b281d6723641d71b Mon Sep 17 00:00:00 2001 From: martonp Date: Tue, 23 Jul 2024 12:03:35 +0200 Subject: [PATCH] mm: Audit mutexes Fixes of mutex use in market making code. --- client/mm/exchange_adaptor.go | 56 +++++++++++++++++++++++++++-------- client/mm/libxc/binance.go | 2 +- client/mm/libxc/orderbook.go | 3 ++ client/mm/mm.go | 14 +++++++-- 4 files changed, 58 insertions(+), 17 deletions(-) diff --git a/client/mm/exchange_adaptor.go b/client/mm/exchange_adaptor.go index ad8bc00b75..5349fd0554 100644 --- a/client/mm/exchange_adaptor.go +++ b/client/mm/exchange_adaptor.go @@ -196,6 +196,7 @@ func withdrawalBalanceEffects(tx *asset.WalletTransaction, cexDebit uint64, asse func (w *pendingWithdrawal) balanceEffects() (dex, cex *BalanceEffects) { w.txMtx.RLock() defer w.txMtx.RUnlock() + return withdrawalBalanceEffects(w.tx, w.amtWithdrawn, w.assetID) } @@ -516,6 +517,10 @@ func (u *unifiedExchangeAdaptor) withPause(f func() error) error { return u.botLoop.ConnectOnce(u.ctx) } +// logBalanceAdjustments logs a trace log of balance adjustments and updated +// settled balances. +// +// balancesMtx must be read locked when calling this function. func (u *unifiedExchangeAdaptor) logBalanceAdjustments(dexDiffs, cexDiffs map[uint32]int64, reason string) { if u.log.Level() > dex.LevelTrace { return @@ -552,10 +557,10 @@ func (u *unifiedExchangeAdaptor) logBalanceAdjustments(dexDiffs, cexDiffs map[ui writeLine("Updated settled balances:") writeLine(" DEX:") + for assetID, bal := range u.baseDexBalances { writeLine(" " + format(assetID, bal, false)) } - if len(u.baseCexBalances) > 0 { writeLine(" CEX:") for assetID, bal := range u.baseCexBalances { @@ -690,14 +695,13 @@ type dexOrderInfo struct { // updateDEXOrderEvent updates the event log with the current state of a // pending DEX order and sends an event notification. func (u *unifiedExchangeAdaptor) updateDEXOrderEvent(o *pendingDEXOrder, complete bool) { + o.txsMtx.RLock() transactions := make([]*asset.WalletTransaction, 0, len(o.swaps)+len(o.redeems)+len(o.refunds)) addTxs := func(txs map[string]*asset.WalletTransaction) { for _, tx := range txs { transactions = append(transactions, tx) } } - - o.txsMtx.RLock() addTxs(o.swaps) addTxs(o.redeems) addTxs(o.refunds) @@ -1283,7 +1287,7 @@ type BotBalances struct { CEX *BotBalance `json:"cex"` } -// dexBalance must be called with the balancesMtx locked. +// dexBalance must be called with the balancesMtx read locked. func (u *unifiedExchangeAdaptor) dexBalance(assetID uint32) *BotBalance { bal, found := u.baseDexBalances[assetID] if !found { @@ -1376,7 +1380,11 @@ func (u *unifiedExchangeAdaptor) refreshAllPendingEvents(ctx context.Context) { } for _, pendingDeposit := range pendingDeposits { - u.confirmDeposit(ctx, pendingDeposit.tx.ID) + pendingDeposit.mtx.RLock() + id := pendingDeposit.tx.ID + pendingDeposit.mtx.RUnlock() + + u.confirmDeposit(ctx, id) } for _, pendingWithdrawal := range pendingWithdrawals { @@ -1384,11 +1392,16 @@ func (u *unifiedExchangeAdaptor) refreshAllPendingEvents(ctx context.Context) { } for _, pendingOrder := range pendingCEXOrders { - trade, err := u.CEX.TradeStatus(ctx, pendingOrder.trade.ID, pendingOrder.trade.BaseID, pendingOrder.trade.QuoteID) + pendingOrder.tradeMtx.RLock() + id, baseID, quoteID := pendingOrder.trade.ID, pendingOrder.trade.BaseID, pendingOrder.trade.QuoteID + pendingOrder.tradeMtx.RUnlock() + + trade, err := u.CEX.TradeStatus(ctx, id, baseID, quoteID) if err != nil { u.log.Errorf("error getting CEX trade status: %v", err) continue } + u.handleCEXTradeUpdate(trade) } } @@ -1423,7 +1436,7 @@ func cexTradeBalanceEffects(trade *libxc.Trade) (effects *BalanceEffects) { return } -// cexBalance must be called with the balancesMtx locked. +// cexBalance must be called with the balancesMtx read locked. func (u *unifiedExchangeAdaptor) cexBalance(assetID uint32) *BotBalance { totalEffects := newBalanceEffects() addEffects := func(effects *BalanceEffects) { @@ -1442,7 +1455,11 @@ func (u *unifiedExchangeAdaptor) cexBalance(assetID uint32) *BotBalance { } for _, pendingOrder := range u.pendingCEXOrders { - addEffects(cexTradeBalanceEffects(pendingOrder.trade)) + pendingOrder.tradeMtx.RLock() + trade := pendingOrder.trade + pendingOrder.tradeMtx.RUnlock() + + addEffects(cexTradeBalanceEffects(trade)) } for _, withdrawal := range u.pendingWithdrawals { @@ -1568,7 +1585,10 @@ func (u *unifiedExchangeAdaptor) pendingDepositComplete(deposit *pendingDeposit) } u.sendStatsUpdate() + + u.balancesMtx.RLock() u.logBalanceAdjustments(dexDiffs, cexDiffs, msg) + u.balancesMtx.RUnlock() } func (u *unifiedExchangeAdaptor) confirmDeposit(ctx context.Context, txID string) bool { @@ -1733,7 +1753,10 @@ func (u *unifiedExchangeAdaptor) pendingWithdrawalComplete(id string, tx *asset. dexDiffs := map[uint32]int64{withdrawal.assetID: dexEffects.Settled[withdrawal.assetID]} cexDiffs := map[uint32]int64{withdrawal.assetID: cexEffects.Settled[withdrawal.assetID]} + + u.balancesMtx.RLock() u.logBalanceAdjustments(dexDiffs, cexDiffs, fmt.Sprintf("Withdrawal %s complete.", id)) + u.balancesMtx.RUnlock() } func (u *unifiedExchangeAdaptor) confirmWithdrawal(ctx context.Context, id string) bool { @@ -1809,22 +1832,24 @@ func (u *unifiedExchangeAdaptor) withdraw(ctx context.Context, assetID uint32, a u.balancesMtx.Unlock() return err } + u.log.Infof("Withdrew %s", u.fmtQty(assetID, amount)) if assetID == u.baseID { u.pendingBaseRebalance.Store(true) } else { u.pendingQuoteRebalance.Store(true) } - u.pendingWithdrawals[withdrawalID] = &pendingWithdrawal{ + withdrawal := &pendingWithdrawal{ eventLogID: u.eventLogID.Add(1), timestamp: time.Now().Unix(), assetID: assetID, amtWithdrawn: amount, withdrawalID: withdrawalID, } + u.pendingWithdrawals[withdrawalID] = withdrawal u.balancesMtx.Unlock() - u.updateWithdrawalEvent(u.pendingWithdrawals[withdrawalID], nil) + u.updateWithdrawalEvent(withdrawal, nil) u.sendStatsUpdate() u.wg.Add(1) @@ -1973,6 +1998,7 @@ func (u *unifiedExchangeAdaptor) handleCEXTradeUpdate(trade *libxc.Trade) { u.balancesMtx.Lock() defer u.balancesMtx.Unlock() + delete(u.pendingCEXOrders, trade.ID) if trade.BaseFilled == 0 && trade.QuoteFilled == 0 { @@ -2234,7 +2260,11 @@ func (u *unifiedExchangeAdaptor) cancelAllOrders(ctx context.Context) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - tradeStatus, err := u.CEX.TradeStatus(ctx, pendingOrder.trade.ID, pendingOrder.trade.BaseID, pendingOrder.trade.QuoteID) + pendingOrder.tradeMtx.RLock() + id, baseID, quoteID := pendingOrder.trade.ID, pendingOrder.trade.BaseID, pendingOrder.trade.QuoteID + pendingOrder.tradeMtx.RUnlock() + + tradeStatus, err := u.CEX.TradeStatus(ctx, id, baseID, quoteID) if err != nil { u.log.Errorf("Error getting CEX trade status: %v", err) continue @@ -2244,9 +2274,9 @@ func (u *unifiedExchangeAdaptor) cancelAllOrders(ctx context.Context) { } done = false - err = u.CEX.CancelTrade(ctx, u.baseID, u.quoteID, pendingOrder.trade.ID) + err = u.CEX.CancelTrade(ctx, baseID, quoteID, id) if err != nil { - u.log.Errorf("Error canceling CEX trade %s: %v", pendingOrder.trade.ID, err) + u.log.Errorf("Error canceling CEX trade %s: %v", id, err) } } diff --git a/client/mm/libxc/binance.go b/client/mm/libxc/binance.go index a613217307..5464473bc0 100644 --- a/client/mm/libxc/binance.go +++ b/client/mm/libxc/binance.go @@ -151,7 +151,7 @@ func (b *binanceOrderBook) sync(ctx context.Context) { func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error /* no errors */) { const updateIDUnsynced = math.MaxUint64 - // We'll run two goroutines and sychronize two local vars. + // We'll run two goroutines and synchronize two local vars. var syncMtx sync.Mutex var syncCache []*bntypes.BookUpdate syncChan := make(chan struct{}) diff --git a/client/mm/libxc/orderbook.go b/client/mm/libxc/orderbook.go index 2deaeafe90..78eced2cf3 100644 --- a/client/mm/libxc/orderbook.go +++ b/client/mm/libxc/orderbook.go @@ -149,6 +149,9 @@ func (ob *orderbook) vwap(bids bool, qty uint64) (vwap, extrema uint64, filled b } func (ob *orderbook) midGap() uint64 { + ob.mtx.RLock() + defer ob.mtx.RUnlock() + bestBuyI := ob.bids.Front() if bestBuyI == nil { return 0 diff --git a/client/mm/mm.go b/client/mm/mm.go index 8122789bf0..2336ca7fd1 100644 --- a/client/mm/mm.go +++ b/client/mm/mm.go @@ -467,7 +467,7 @@ func (m *MarketMaker) handleCEXUpdate(cexName string, ni interface{}) { cex := m.cexes[cexName] m.cexMtx.RUnlock() if cex == nil { - m.log.Errorf("CEX update received from uknown cex %q?", cexName) + m.log.Errorf("CEX update received from unknown cex %q?", cexName) return } cex.mtx.Lock() @@ -481,10 +481,12 @@ func (m *MarketMaker) handleCEXUpdate(cexName string, ni interface{}) { func (m *MarketMaker) cexList() []*centralizedExchange { m.cexMtx.RLock() defer m.cexMtx.RUnlock() + cexes := make([]*centralizedExchange, 0, len(m.cexes)) for _, cex := range m.cexes { cexes = append(cexes, cex) } + return cexes } @@ -517,13 +519,18 @@ func (m *MarketMaker) Connect(ctx context.Context) (*sync.WaitGroup, error) { go func() { defer wg.Done() <-ctx.Done() - for _, cex := range m.cexList() { + + m.cexMtx.Lock() + defer m.cexMtx.Unlock() + + for _, cex := range m.cexes { cex.mtx.RLock() cm := cex.cm cex.mtx.RUnlock() if cm != nil { cm.Disconnect() } + delete(m.cexes, cex.Name) } }() @@ -843,7 +850,8 @@ func (m *MarketMaker) UpdateCEXConfig(updatedCfg *CEXConfig) error { m.defaultCfg.CexConfigs = append(m.defaultCfg.CexConfigs, updatedCfg) } m.defaultCfgMtx.Unlock() - if err := m.writeConfigFile(m.defaultCfg); err != nil { + + if err := m.writeConfigFile(m.defaultConfig()); err != nil { m.log.Errorf("Error saving new bot configuration: %w", err) }