Skip to content

Commit

Permalink
server: allow ServerSession.WritePacket*() to return an error
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 committed Aug 14, 2023
1 parent 2897122 commit 482b047
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 76 deletions.
30 changes: 28 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1678,13 +1678,39 @@ func (c *Client) WritePacketRTP(medi *media.Media, pkt *rtp.Packet) error {

// WritePacketRTPWithNTP writes a RTP packet to the media stream.
func (c *Client) WritePacketRTPWithNTP(medi *media.Media, pkt *rtp.Packet, ntp time.Time) error {
byts := make([]byte, udpMaxPayloadSize)
n, err := pkt.MarshalTo(byts)
if err != nil {
return err
}
byts = byts[:n]

select {
case <-c.done:
return c.closeError
default:
}

cm := c.medias[medi]
ct := cm.formats[pkt.PayloadType]
return ct.writePacketRTPWithNTP(pkt, ntp)
ct.writePacketRTP(byts, pkt, ntp)
return nil
}

// WritePacketRTCP writes a RTCP packet to the media stream.
func (c *Client) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) error {
byts, err := pkt.Marshal()
if err != nil {
return err
}

select {
case <-c.done:
return c.closeError
default:
}

cm := c.medias[medi]
return cm.writePacketRTCP(pkt)
cm.writePacketRTCP(byts)
return nil
}
34 changes: 9 additions & 25 deletions client_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
)

type clientFormat struct {
c *Client
cm *clientMedia
format formats.Format
udpReorderer *rtpreorderer.Reorderer // play
Expand All @@ -27,7 +26,6 @@ type clientFormat struct {

func newClientFormat(cm *clientMedia, forma formats.Format) *clientFormat {
return &clientFormat{
c: cm.c,
cm: cm,
format: forma,
onPacketRTP: func(*rtp.Packet) {},
Expand All @@ -44,7 +42,7 @@ func (ct *clientFormat) start() {
ct.cm.c.udpReceiverReportPeriod,
nil,
ct.format.ClockRate(), func(pkt rtcp.Packet) {
ct.cm.writePacketRTCP(pkt) //nolint:errcheck
ct.cm.c.WritePacketRTCP(ct.cm.media, pkt) //nolint:errcheck
})
if err != nil {
panic(err)
Expand All @@ -56,15 +54,15 @@ func (ct *clientFormat) start() {
ct.rtcpSender = rtcpsender.New(
ct.format.ClockRate(),
func(pkt rtcp.Packet) {
ct.cm.writePacketRTCP(pkt) //nolint:errcheck
ct.cm.c.WritePacketRTCP(ct.cm.media, pkt) //nolint:errcheck
})
}
}

// start writing after write*() has been allocated in order to avoid a crash
func (ct *clientFormat) startWriting() {
if ct.c.state != clientStatePlay && !ct.c.DisableRTCPSenderReports {
ct.rtcpSender.Start(ct.c.senderReportPeriod)
if ct.cm.c.state != clientStatePlay && !ct.cm.c.DisableRTCPSenderReports {
ct.rtcpSender.Start(ct.cm.c.senderReportPeriod)
}
}

Expand All @@ -79,32 +77,18 @@ func (ct *clientFormat) stop() {
}
}

func (ct *clientFormat) writePacketRTPWithNTP(pkt *rtp.Packet, ntp time.Time) error {
byts := make([]byte, udpMaxPayloadSize)
n, err := pkt.MarshalTo(byts)
if err != nil {
return err
}
byts = byts[:n]

select {
case <-ct.c.done:
return ct.c.closeError
default:
}
func (ct *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) {
ct.rtcpSender.ProcessPacket(pkt, ntp, ct.format.PTSEqualsDTS(pkt))

ct.c.writer.queue(func() {
ct.cm.c.writer.queue(func() {
ct.cm.writePacketRTPInQueue(byts)
})

ct.rtcpSender.ProcessPacket(pkt, ntp, ct.format.PTSEqualsDTS(pkt))
return nil
}

func (ct *clientFormat) readRTPUDP(pkt *rtp.Packet) {
packets, lost := ct.udpReorderer.Process(pkt)
if lost != 0 {
ct.c.OnPacketLost(fmt.Errorf("%d RTP %s lost",
ct.cm.c.OnPacketLost(fmt.Errorf("%d RTP %s lost",
lost,
func() string {
if lost == 1 {
Expand All @@ -126,7 +110,7 @@ func (ct *clientFormat) readRTPUDP(pkt *rtp.Packet) {
func (ct *clientFormat) readRTPTCP(pkt *rtp.Packet) {
lost := ct.tcpLossDetector.Process(pkt)
if lost != 0 {
ct.c.OnPacketLost(fmt.Errorf("%d RTP %s lost",
ct.cm.c.OnPacketLost(fmt.Errorf("%d RTP %s lost",
lost,
func() string {
if lost == 1 {
Expand Down
15 changes: 1 addition & 14 deletions client_media.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,23 +180,10 @@ func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) {
cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck
}

func (cm *clientMedia) writePacketRTCP(pkt rtcp.Packet) error {
byts, err := pkt.Marshal()
if err != nil {
return err
}

select {
case <-cm.c.done:
return cm.c.closeError
default:
}

func (cm *clientMedia) writePacketRTCP(byts []byte) {
cm.c.writer.queue(func() {
cm.writePacketRTCPInQueue(byts)
})

return nil
}

func (cm *clientMedia) readRTPTCPPlay(payload []byte) {
Expand Down
3 changes: 2 additions & 1 deletion server_play_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ func TestServerPlay(t *testing.T) {
// send RTCP packets directly to the session.
// these are sent after the response, only if onPlay returns StatusOK.
if transport != "multicast" {
ctx.Session.WritePacketRTCP(stream.Medias()[0], &testRTCPPacket)
err := ctx.Session.WritePacketRTCP(stream.Medias()[0], &testRTCPPacket)
require.NoError(t, err)
}

ctx.Session.OnPacketRTCPAny(func(medi *media.Media, pkt rtcp.Packet) {
Expand Down
9 changes: 6 additions & 3 deletions server_record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,10 @@ func TestServerRecord(t *testing.T) {
onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) {
// send RTCP packets directly to the session.
// these are sent after the response, only if onRecord returns StatusOK.
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket)
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[1], &testRTCPPacket)
err := ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket)
require.NoError(t, err)
err = ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[1], &testRTCPPacket)
require.NoError(t, err)

for i := 0; i < 2; i++ {
ctx.Session.OnPacketRTP(
Expand All @@ -538,7 +540,8 @@ func TestServerRecord(t *testing.T) {
ctx.Session.AnnouncedMedias()[i],
func(pkt rtcp.Packet) {
require.Equal(t, &testRTCPPacket, pkt)
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[ci], &testRTCPPacket)
err := ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[ci], &testRTCPPacket)
require.NoError(t, err)
})
}

Expand Down
10 changes: 6 additions & 4 deletions server_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1172,13 +1172,14 @@ func (ss *ServerSession) writePacketRTP(medi *media.Media, byts []byte) {
}

// WritePacketRTP writes a RTP packet to the session.
func (ss *ServerSession) WritePacketRTP(medi *media.Media, pkt *rtp.Packet) {
func (ss *ServerSession) WritePacketRTP(medi *media.Media, pkt *rtp.Packet) error {
byts, err := pkt.Marshal()
if err != nil {
return
return err
}

ss.writePacketRTP(medi, byts)
return nil
}

func (ss *ServerSession) writePacketRTCP(medi *media.Media, byts []byte) {
Expand All @@ -1187,13 +1188,14 @@ func (ss *ServerSession) writePacketRTCP(medi *media.Media, byts []byte) {
}

// WritePacketRTCP writes a RTCP packet to the session.
func (ss *ServerSession) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) {
func (ss *ServerSession) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) error {
byts, err := pkt.Marshal()
if err != nil {
return
return err
}

ss.writePacketRTCP(medi, byts)
return nil
}

func (ss *ServerSession) handleRequest(req sessionRequestReq) (*base.Response, *ServerSession, error) {
Expand Down
2 changes: 1 addition & 1 deletion server_session_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (sf *serverSessionFormat) start() {
nil,
sf.format.ClockRate(),
func(pkt rtcp.Packet) {
sf.sm.ss.WritePacketRTCP(sf.sm.media, pkt)
sf.sm.ss.WritePacketRTCP(sf.sm.media, pkt) //nolint:errcheck
})
if err != nil {
panic(err)
Expand Down
4 changes: 2 additions & 2 deletions server_session_media.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ func (sm *serverSessionMedia) start() {
sm.ss.s.udpRTCPListener.addClient(sm.ss.author.ip(), sm.udpRTCPReadPort, sm.readRTCPUDPPlay)
} else {
// open the firewall by sending empty packets to the counterpart.
sm.ss.WritePacketRTP(sm.media, &rtp.Packet{Header: rtp.Header{Version: 2}})
sm.ss.WritePacketRTCP(sm.media, &rtcp.ReceiverReport{})
sm.ss.WritePacketRTP(sm.media, &rtp.Packet{Header: rtp.Header{Version: 2}}) //nolint:errcheck
sm.ss.WritePacketRTCP(sm.media, &rtcp.ReceiverReport{}) //nolint:errcheck

sm.ss.s.udpRTPListener.addClient(sm.ss.author.ip(), sm.udpRTPReadPort, sm.readRTPUDPRecord)
sm.ss.s.udpRTCPListener.addClient(sm.ss.author.ip(), sm.udpRTCPReadPort, sm.readRTCPUDPRecord)
Expand Down
18 changes: 16 additions & 2 deletions server_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,13 @@ func (st *ServerStream) WritePacketRTP(medi *media.Media, pkt *rtp.Packet) error
// ntp is the absolute time of the packet, and is needed to generate RTCP sender reports
// that allows the receiver to reconstruct the absolute time of the packet.
func (st *ServerStream) WritePacketRTPWithNTP(medi *media.Media, pkt *rtp.Packet, ntp time.Time) error {
byts := make([]byte, udpMaxPayloadSize)
n, err := pkt.MarshalTo(byts)
if err != nil {
return err
}
byts = byts[:n]

st.mutex.RLock()
defer st.mutex.RUnlock()

Expand All @@ -268,11 +275,17 @@ func (st *ServerStream) WritePacketRTPWithNTP(medi *media.Media, pkt *rtp.Packet
}

sm := st.streamMedias[medi]
return sm.writePacketRTPWithNTP(st, pkt, ntp)
sm.writePacketRTP(byts, pkt, ntp)
return nil
}

// WritePacketRTCP writes a RTCP packet to all the readers of the stream.
func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) error {
byts, err := pkt.Marshal()
if err != nil {
return err
}

st.mutex.RLock()
defer st.mutex.RUnlock()

Expand All @@ -281,5 +294,6 @@ func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) erro
}

sm := st.streamMedias[medi]
return sm.writePacketRTCP(st, pkt)
sm.writePacketRTCP(byts)
return nil
}
30 changes: 8 additions & 22 deletions server_stream_media.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ import (
)

type serverStreamMedia struct {
trackID int
st *ServerStream
media *media.Media
trackID int
formats map[uint8]*serverStreamFormat
multicastWriter *serverMulticastWriter
}

func newServerStreamMedia(st *ServerStream, medi *media.Media, trackID int) *serverStreamMedia {
sm := &serverStreamMedia{
trackID: trackID,
st: st,
media: medi,
trackID: trackID,
}

sm.formats = make(map[uint8]*serverStreamFormat)
Expand Down Expand Up @@ -67,20 +69,13 @@ func (sm *serverStreamMedia) allocateMulticastHandler(s *Server) error {
return nil
}

func (sm *serverStreamMedia) writePacketRTPWithNTP(ss *ServerStream, pkt *rtp.Packet, ntp time.Time) error {
byts := make([]byte, udpMaxPayloadSize)
n, err := pkt.MarshalTo(byts)
if err != nil {
return err
}
byts = byts[:n]

func (sm *serverStreamMedia) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) {
forma := sm.formats[pkt.PayloadType]

forma.rtcpSender.ProcessPacket(pkt, ntp, forma.format.PTSEqualsDTS(pkt))

// send unicast
for r := range ss.activeUnicastReaders {
for r := range sm.st.activeUnicastReaders {
sm, ok := r.setuppedMedias[sm.media]
if ok {
sm.writePacketRTP(byts)
Expand All @@ -91,18 +86,11 @@ func (sm *serverStreamMedia) writePacketRTPWithNTP(ss *ServerStream, pkt *rtp.Pa
if sm.multicastWriter != nil {
sm.multicastWriter.writePacketRTP(byts)
}

return nil
}

func (sm *serverStreamMedia) writePacketRTCP(ss *ServerStream, pkt rtcp.Packet) error {
byts, err := pkt.Marshal()
if err != nil {
return err
}

func (sm *serverStreamMedia) writePacketRTCP(byts []byte) {
// send unicast
for r := range ss.activeUnicastReaders {
for r := range sm.st.activeUnicastReaders {
sm, ok := r.setuppedMedias[sm.media]
if ok {
sm.writePacketRTCP(byts)
Expand All @@ -113,6 +101,4 @@ func (sm *serverStreamMedia) writePacketRTCP(ss *ServerStream, pkt rtcp.Packet)
if sm.multicastWriter != nil {
sm.multicastWriter.writePacketRTCP(byts)
}

return nil
}

0 comments on commit 482b047

Please sign in to comment.