diff --git a/internal/transport/conn_tcp.go b/internal/transport/conn_tcp.go index e261853..e67947d 100644 --- a/internal/transport/conn_tcp.go +++ b/internal/transport/conn_tcp.go @@ -17,7 +17,7 @@ type tcpConn struct { } func NewTCP(host string, timeout time.Duration) (Conn, error) { - return NewTCPWithCtx(nil, host, timeout) + return NewTCPWithCtx(context.Background(), host, timeout) } func NewTCPWithCtx(ctx context.Context, host string, timeout time.Duration) (Conn, error) { @@ -30,10 +30,6 @@ func NewTCPWithCtx(ctx context.Context, host string, timeout time.Duration) (Con return nil, errors.Wrap(err, "dialing tcp") } - if ctx == nil { - ctx = context.Background() - } - return &tcpConn{ cancelReader: ioutil.NewCancelableReader(ctx, conn), conn: conn, diff --git a/internal/transport/interfaces.go b/internal/transport/interfaces.go index 4ed3ca5..2d73282 100644 --- a/internal/transport/interfaces.go +++ b/internal/transport/interfaces.go @@ -6,7 +6,7 @@ import ( type Conn io.ReadWriteCloser -type Transport interface { +type Mode interface { WriteMsg(msg []byte) error // this is not same as the io.Writer ReadMsg() ([]byte, error) } diff --git a/internal/transport/intermediate_mode.go b/internal/transport/intermediate_mode.go index 4a518dc..ab40c64 100644 --- a/internal/transport/intermediate_mode.go +++ b/internal/transport/intermediate_mode.go @@ -15,7 +15,7 @@ type intermediateMode struct { var transportModeIntermediate = [...]byte{0xee, 0xee, 0xee, 0xee} // meta:immutable -func NewIntermediateMode(conn io.ReadWriter) (Transport, error) { +func NewIntermediateMode(conn io.ReadWriter) (Mode, error) { if conn == nil { return nil, errors.New("conn is nil") } diff --git a/internal/transport/mode_test.go b/internal/transport/mode_test.go new file mode 100644 index 0000000..98b6c7b --- /dev/null +++ b/internal/transport/mode_test.go @@ -0,0 +1,33 @@ +package transport_test + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + "github.com/xelaj/mtproto/internal/transport" +) + +func TestMode(t *testing.T) { + buf := bytes.NewBuffer(nil) + + tr, err := transport.NewIntermediateMode(buf) + require.NoError(t, err) + + require.NoError(t, tr.WriteMsg([]byte("test message"))) + require.Equal(t, buf.Bytes(), []byte{ + 0xee, 0xee, 0xee, 0xee, 0x0c, 0x00, 0x00, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x20, 0x6d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, + }) + + tr, err = transport.NewIntermediateMode(bytes.NewBuffer([]byte{ + 0x0c, 0x00, 0x00, 0x00, 0x74, 0x65, 0x73, 0x74, + 0x20, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + })) + require.NoError(t, err) + + res, err := tr.ReadMsg() + require.NoError(t, err) + require.Equal(t, []byte("test message"), res) +} diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go deleted file mode 100644 index bec323d..0000000 --- a/internal/transport/transport_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package transport_test - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/xelaj/mtproto/internal/transport" -) - -func TestMode(t *testing.T) { - buf := bytes.NewBuffer(nil) - - transport, err := transport.NewIntermediateMode(buf) - require.NoError(t, err) - - require.NoError(t, transport.WriteMsg([]byte("test message"))) - assert.Equal(t, buf.Bytes(), []byte{ - 0xee, 0xee, 0xee, 0xee, 0x0c, 0x00, 0x00, 0x00, - 0x74, 0x65, 0x73, 0x74, 0x20, 0x6d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, - }) -} diff --git a/network.go b/network.go index 02df760..b297202 100644 --- a/network.go +++ b/network.go @@ -8,7 +8,6 @@ package mtproto import ( "encoding/binary" "fmt" - "time" "github.com/pkg/errors" @@ -17,11 +16,6 @@ import ( "github.com/xelaj/mtproto/internal/mtproto/objects" ) -const ( - // если длина пакета больше или равн 127 слов, то кодируем 4 байтами, 1 это магическое число, оставшиеся 3 — дилна - magicValueSizeMoreThanSingleByte = 0x7f -) - // проверяет, надо ли ждать от сервера пинга func isNullableResponse(t tl.Object) bool { switch t.(type) { @@ -32,12 +26,8 @@ func isNullableResponse(t tl.Object) bool { } } -const ( - readTimeout = 2 * time.Second -) - func CatchResponseErrorCode(data []byte) error { - if len(data) == 4 { + if len(data) == tl.WordLen { code := int(binary.LittleEndian.Uint32(data)) return &ErrResponseCode{Code: code} } @@ -70,11 +60,9 @@ func (m *MTProto) decodeRecievedData(data []byte) (messages.Common, error) { return nil, errors.Wrap(err, "parsing message") } - m.msgId = int64(msg.GetMsgID()) - m.seqNo = int32(msg.GetSeqNo()) - mod := m.msgId & 3 + mod := msg.GetMsgID() & 3 if mod != 1 && mod != 3 { - return nil, fmt.Errorf("Wrong bits of message_id: %d", mod) + return nil, fmt.Errorf("wrong bits of message_id: %d", mod) } return msg, nil diff --git a/utils.go b/utils.go index c5671b4..b3fcd7e 100644 --- a/utils.go +++ b/utils.go @@ -6,6 +6,9 @@ package mtproto import ( + "context" + "io" + "github.com/xelaj/mtproto/internal/encoding/tl" "github.com/xelaj/mtproto/internal/mtproto/objects" ) @@ -23,14 +26,6 @@ func defaultDCList() map[int]string { } } -// https://core.telegram.org/mtproto/mtproto-transports -var ( - transportModeAbridged = [...]byte{0xef} // meta:immutable - transportModeIntermediate = [...]byte{0xee, 0xee, 0xee, 0xee} // meta:immutable - transportModePaddedIntermediate = [...]byte{0xdd, 0xdd, 0xdd, 0xdd} // meta:immutable - transportModeFull = [...]byte{} // meta:immutable -) - func MessageRequireToAck(msg tl.Object) bool { switch msg.(type) { case /**objects.Ping,*/ *objects.MsgsAck: @@ -39,3 +34,10 @@ func MessageRequireToAck(msg tl.Object) bool { return true } } + +func CloseOnCancel(ctx context.Context, c io.Closer) { + go func() { + <-ctx.Done() + c.Close() + }() +}