Skip to content

Commit

Permalink
🏗 capital refactoring, switch to transport package
Browse files Browse the repository at this point in the history
  • Loading branch information
quenbyako committed Mar 24, 2021
1 parent 2e19c0c commit df86cba
Show file tree
Hide file tree
Showing 12 changed files with 444 additions and 352 deletions.
3 changes: 2 additions & 1 deletion handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ func (m *MTProto) makeAuthKey() error { // nolint don't know how to make method
hex.EncodeToString(dhg.NewNonceHash1.Bytes()),
)
}
m.serviceModeActivated = false

// (all ok)
m.serviceModeActivated = false
m.encrypted = true
err = m.SaveSession()
return errors.Wrap(err, "saving session")
}
14 changes: 8 additions & 6 deletions internal/transport/conn_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ type tcpConn struct {
timeout time.Duration
}

func NewTCP(host string, timeout time.Duration) (Conn, error) {
return NewTCPWithCtx(context.Background(), host, timeout)
type TCPConnConfig struct {
Ctx context.Context
Host string
Timeout time.Duration
}

func NewTCPWithCtx(ctx context.Context, host string, timeout time.Duration) (Conn, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", host)
func NewTCP(cfg TCPConnConfig) (Conn, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", cfg.Host)
if err != nil {
return nil, errors.Wrap(err, "resolving tcp")
}
Expand All @@ -31,9 +33,9 @@ func NewTCPWithCtx(ctx context.Context, host string, timeout time.Duration) (Con
}

return &tcpConn{
cancelReader: ioutil.NewCancelableReader(ctx, conn),
cancelReader: ioutil.NewCancelableReader(cfg.Ctx, conn),
conn: conn,
timeout: timeout,
timeout: cfg.Timeout,
}, nil
}

Expand Down
2 changes: 2 additions & 0 deletions internal/transport/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"io"
)

type ConnConfig interface{}
type Conn io.ReadWriteCloser

type ModeConfig = func(Conn) (Mode, error)
type Mode interface {
WriteMsg(msg []byte) error // this is not same as the io.Writer
ReadMsg() ([]byte, error)
Expand Down
2 changes: 1 addition & 1 deletion internal/transport/intermediate_mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type intermediateMode struct {

var transportModeIntermediate = [...]byte{0xee, 0xee, 0xee, 0xee} // meta:immutable

func NewIntermediateMode(conn io.ReadWriter) (Mode, error) {
func NewIntermediateMode(conn Conn) (Mode, error) {
if conn == nil {
return nil, errors.New("conn is nil")
}
Expand Down
15 changes: 12 additions & 3 deletions internal/transport/mode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package transport_test

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -11,7 +12,7 @@ import (
func TestMode(t *testing.T) {
buf := bytes.NewBuffer(nil)

tr, err := transport.NewIntermediateMode(buf)
tr, err := transport.NewIntermediateMode(DummyConn(buf))
require.NoError(t, err)

require.NoError(t, tr.WriteMsg([]byte("test message")))
Expand All @@ -21,13 +22,21 @@ func TestMode(t *testing.T) {
0x73, 0x61, 0x67, 0x65,
})

tr, err = transport.NewIntermediateMode(bytes.NewBuffer([]byte{
tr, err = transport.NewIntermediateMode(DummyConn(bytes.NewBuffer([]byte{
0x0c, 0x00, 0x00, 0x00, 0x74, 0x65, 0x73, 0x74,
0x20, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
}))
})))
require.NoError(t, err)

res, err := tr.ReadMsg()
require.NoError(t, err)
require.Equal(t, []byte("test message"), res)
}

func DummyConn(rw io.ReadWriter) transport.Conn { return nopCloser{ReadWriter: rw} }

type nopCloser struct {
io.ReadWriter
}

func (nopCloser) Close() error { return nil }
126 changes: 126 additions & 0 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package transport

import (
"context"
"encoding/binary"
"fmt"
"io"
"reflect"

"github.com/pkg/errors"
"github.com/xelaj/mtproto/internal/encoding/tl"
"github.com/xelaj/mtproto/internal/mtproto/messages"
)

type Transport interface {
Close() error
WriteMsg(msg messages.Common, requireToAck bool) error
ReadMsg() (messages.Common, error)
}

type transport struct {
conn Conn
mode Mode
m messages.MessageInformator
}

func NewTransport(m messages.MessageInformator, conn ConnConfig, mode ModeConfig) (Transport, error) {
t := &transport{
m: m,
}

var err error
switch cfg := conn.(type) {
case TCPConnConfig:
t.conn, err = NewTCP(cfg)
default:
return nil, fmt.Errorf("unsupported connection type %v", reflect.TypeOf(conn).String())
}
if err != nil {
return nil, errors.Wrap(err, "setup connection")
}

t.mode, err = mode(t.conn)
if err != nil {
return nil, errors.Wrap(err, "setup mode")
}

return t, nil
}

func (t *transport) Close() error {
return t.conn.Close()
}

func (t *transport) WriteMsg(msg messages.Common, requireToAck bool) error {
var data []byte
switch message := msg.(type) {
case *messages.Unencrypted:
data, _ = message.Serialize(t.m)

case *messages.Encrypted:
var err error
data, err = message.Serialize(t.m, requireToAck)
if err != nil {
return errors.Wrap(err, "serializing message")
}

default:
return fmt.Errorf("supported only mtproto predefined messages, got %v", reflect.TypeOf(msg).String())
}

err := t.mode.WriteMsg(data)
if err != nil {
return errors.Wrap(err, "sending request")
}
return nil
}

func (t *transport) ReadMsg() (messages.Common, error) {
data, err := t.mode.ReadMsg()
if err != nil {
switch err {
case io.EOF, context.Canceled:
return nil, err
default:
return nil, errors.Wrap(err, "reading message")
}
}

// checking that response is not error code
if len(data) == tl.WordLen {
code := int(binary.LittleEndian.Uint32(data))
return nil, ErrCode(code)
}

var msg messages.Common
if isPacketEncrypted(data) {
msg, err = messages.DeserializeEncrypted(data, t.m.GetAuthKey())
} else {
msg, err = messages.DeserializeUnencrypted(data)
}
if err != nil {
return nil, errors.Wrap(err, "parsing message")
}

mod := msg.GetMsgID() & 3 // why 3? only god knows why
if mod != 1 && mod != 3 {
return nil, fmt.Errorf("wrong bits of message_id: %d", mod)
}

return msg, nil
}

func isPacketEncrypted(data []byte) bool {
if len(data) < tl.DoubleLen {
return false
}
authKeyHash := data[:tl.DoubleLen]
return binary.LittleEndian.Uint64(authKeyHash) != 0
}

type ErrCode int

func (e ErrCode) Error() string {
return fmt.Sprintf("code %v", int(e))
}
97 changes: 97 additions & 0 deletions internal/utils/sync_stuff.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
package utils

import (
"reflect"
"sync"

"github.com/xelaj/mtproto/internal/encoding/tl"
)

type SyncSetInt struct {
Expand Down Expand Up @@ -46,3 +49,97 @@ func (s *SyncSetInt) Reset() {
defer s.mutex.Unlock()
s.m = make(map[int]null)
}

type SyncIntObjectChan struct {
mutex sync.RWMutex
m map[int]chan tl.Object
}

func NewSyncIntObjectChan() *SyncIntObjectChan {
return &SyncIntObjectChan{m: make(map[int]chan tl.Object)}
}

func (s *SyncIntObjectChan) Has(key int) bool {
s.mutex.RLock()
_, ok := s.m[key]
s.mutex.RUnlock()
return ok
}

func (s *SyncIntObjectChan) Get(key int) (chan tl.Object, bool) {
s.mutex.RLock()
defer s.mutex.RUnlock()
v, ok := s.m[key]
return v, ok
}

func (s *SyncIntObjectChan) Add(key int, value chan tl.Object) {
s.mutex.Lock()
s.m[key] = value
s.mutex.Unlock()
}

func (s *SyncIntObjectChan) Keys() []int {
keys := make([]int, 0, len(s.m))
s.mutex.RLock()
defer s.mutex.RUnlock()
for k := range s.m {
keys = append(keys, k)
}
return keys
}

func (s *SyncIntObjectChan) Delete(key int) bool {
s.mutex.Lock()
_, ok := s.m[key]
delete(s.m, key)
s.mutex.Unlock()
return ok
}

type SyncIntReflectTypes struct {
mutex sync.RWMutex
m map[int][]reflect.Type
}

func NewSyncIntReflectTypes() *SyncIntReflectTypes {
return &SyncIntReflectTypes{m: make(map[int][]reflect.Type)}
}

func (s *SyncIntReflectTypes) Has(key int) bool {
s.mutex.RLock()
_, ok := s.m[key]
s.mutex.RUnlock()
return ok
}

func (s *SyncIntReflectTypes) Get(key int) ([]reflect.Type, bool) {
s.mutex.RLock()
defer s.mutex.RUnlock()
v, ok := s.m[key]
return v, ok
}

func (s *SyncIntReflectTypes) Add(key int, value []reflect.Type) {
s.mutex.Lock()
s.m[key] = value
s.mutex.Unlock()
}

func (s *SyncIntReflectTypes) Keys() []int {
keys := make([]int, 0, len(s.m))
s.mutex.RLock()
defer s.mutex.RUnlock()
for k := range s.m {
keys = append(keys, k)
}
return keys
}

func (s *SyncIntReflectTypes) Delete(key int) bool {
s.mutex.Lock()
_, ok := s.m[key]
delete(s.m, key)
s.mutex.Unlock()
return ok
}
Loading

0 comments on commit df86cba

Please sign in to comment.