Skip to content

Commit

Permalink
Expose server's middleware for incoming clients control
Browse files Browse the repository at this point in the history
  • Loading branch information
Waldz committed Apr 27, 2020
1 parent 9a181e8 commit 7e55222
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 212 deletions.
134 changes: 66 additions & 68 deletions openvpn/middlewares/server/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,79 @@ package auth

import (
"strings"
"sync"

"github.com/mysteriumnetwork/go-openvpn/openvpn/log"
"github.com/mysteriumnetwork/go-openvpn/openvpn/management"
"github.com/mysteriumnetwork/go-openvpn/openvpn/middlewares/server"
)

type middleware struct {
// TODO: consider implementing event channel to communicate required callbacks
credentialsValidator CredentialsValidator
commandWriter management.CommandWriter
currentEvent server.ClientEvent
}
type ClientEventCallback func(event server.ClientEvent)

// Middleware is able to process client auth events, exposes client control API.
//
// The OpenVPN server should have been started with the
// --management-client-auth directive so that it will ask the management
// interface to approve client connections.
type Middleware struct {
commandWriter management.CommandWriter
currentEvent server.ClientEvent

// CredentialsValidator callback checks given auth primitives (i.e. customer identity signature / node's sessionId)
type CredentialsValidator func(clientID int, username, password string) (bool, error)
listenersMu sync.RWMutex
listeners []ClientEventCallback
}

// NewMiddleware creates server user_auth challenge authentication middleware
func NewMiddleware(credentialsValidator CredentialsValidator) *middleware {
return &middleware{
credentialsValidator: credentialsValidator,
commandWriter: nil,
currentEvent: server.UndefinedEvent,
// NewMiddleware creates new instance of Middleware
func NewMiddleware(listeners ...ClientEventCallback) *Middleware {
return &Middleware{
commandWriter: nil,
currentEvent: server.UndefinedEvent(),
listeners: listeners,
}
}

func (m *middleware) Start(commandWriter management.CommandWriter) error {
func (m *Middleware) ClientsSubscribe(callback ClientEventCallback) {
m.listenersMu.Lock()
defer m.listenersMu.Unlock()

m.listeners = append(m.listeners, callback)
}

func (m *Middleware) ClientAccept(clientID, keyID int) error {
_, err := m.commandWriter.SingleLineCommand("client-auth-nt %d %d", clientID, keyID)
return err
}

func (m *Middleware) ClientDeny(clientID, keyID int, message string) error {
_, err := m.commandWriter.SingleLineCommand("client-deny %d %d", clientID, keyID, message)
return err
}

func (m *Middleware) ClientDenyWithMessage(clientID, keyID int, message string) error {
_, err := m.commandWriter.SingleLineCommand("client-deny %d %d %s", clientID, keyID, message)
return err
}

func (m *Middleware) ClientKill(clientID int) error {
_, err := m.commandWriter.SingleLineCommand("client-kill %d", clientID)
return err
}

func (m *Middleware) ClientKillWithMessage(clientID int, message string) error {
_, err := m.commandWriter.SingleLineCommand("client-kill %d %s", clientID, message)
return err
}

func (m *Middleware) Start(commandWriter management.CommandWriter) error {
m.commandWriter = commandWriter
return nil
}

func (m *middleware) Stop(commandWriter management.CommandWriter) error {
func (m *Middleware) Stop(_ management.CommandWriter) error {
return nil
}

func (m *middleware) ConsumeLine(line string) (bool, error) {
func (m *Middleware) ConsumeLine(line string) (bool, error) {
if !strings.HasPrefix(line, ">CLIENT:") {
return false, nil
}
Expand Down Expand Up @@ -98,69 +136,29 @@ func (m *middleware) ConsumeLine(line string) (bool, error) {
return true, nil
}

func (m *middleware) startOfEvent(eventType server.ClientEventType, clientID int, keyID int) {
func (m *Middleware) startOfEvent(eventType server.ClientEventType, clientID int, keyID int) {
m.currentEvent.EventType = eventType
m.currentEvent.ClientID = clientID
m.currentEvent.ClientKey = keyID
}

func (m *middleware) addEnvVar(key string, val string) {
func (m *Middleware) addEnvVar(key string, val string) {
m.currentEvent.Env[key] = val
}

func (m *middleware) endOfEvent() {
m.handleClientEvent(m.currentEvent)
m.reset()
}
func (m *Middleware) endOfEvent() {
m.listenersMu.RLock()
defer m.listenersMu.RUnlock()

func (m *middleware) reset() {
m.currentEvent = server.UndefinedEvent
}

func (m *middleware) handleClientEvent(event server.ClientEvent) {
switch event.EventType {
case server.Connect, server.Reauth:
username := event.Env["username"]
password := event.Env["password"]
err := m.authenticateClient(event.ClientID, event.ClientKey, username, password)
if err != nil {
log.Error("Unable to authenticate client:", err)
if m.listeners != nil {
for _, subscription := range m.listeners {
subscription(m.currentEvent)
}
case server.Established:
log.Info("Client with ID:", event.ClientID, "connection established successfully")
case server.Disconnect:
log.Info("Client with ID:", event.ClientID, "disconnected")
// NOTE: do not cleanup session after disconnect event risen by transport itself
// cleanup session only by user's intent
}
}

func (m *middleware) authenticateClient(clientID, clientKey int, username, password string) error {

if username == "" || password == "" {
return denyClientAuthWithMessage(m.commandWriter, clientID, clientKey, "missing username or password")
}

log.Info("Authenticating user:", username, "clientID:", clientID, "clientKey:", clientKey)

authenticated, err := m.credentialsValidator(clientID, username, password)
if err != nil {
log.Error("Authentication error:", err)
return denyClientAuthWithMessage(m.commandWriter, clientID, clientKey, "internal error")
}

if authenticated {
return approveClient(m.commandWriter, clientID, clientKey)
}
return denyClientAuthWithMessage(m.commandWriter, clientID, clientKey, "wrong username or password")
}

func approveClient(commandWriter management.CommandWriter, clientID, keyID int) error {
_, err := commandWriter.SingleLineCommand("client-auth-nt %d %d", clientID, keyID)
return err
m.reset()
}

func denyClientAuthWithMessage(commandWriter management.CommandWriter, clientID, keyID int, message string) error {
_, err := commandWriter.SingleLineCommand("client-deny %d %d %s", clientID, keyID, message)
return err
func (m *Middleware) reset() {
m.currentEvent = server.UndefinedEvent()
}
127 changes: 56 additions & 71 deletions openvpn/middlewares/server/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,10 @@ import (
"testing"

"github.com/mysteriumnetwork/go-openvpn/openvpn/management"
"github.com/mysteriumnetwork/go-openvpn/openvpn/middlewares/server"
"github.com/stretchr/testify/assert"
)

type fakeAuthenticatorStub struct {
called bool
username string
password string
authenticated bool
}

func (f *fakeAuthenticatorStub) fakeAuthenticator(clientID int, username, password string) (bool, error) {
f.called = true
f.username = username
f.password = password
return f.authenticated, nil
}

func Test_Factory(t *testing.T) {
fas := fakeAuthenticatorStub{}
middleware := NewMiddleware(fas.fakeAuthenticator)
assert.NotNil(t, middleware)
}

func Test_ConsumeLineSkips(t *testing.T) {
var tests = []struct {
line string
Expand All @@ -53,8 +34,7 @@ func Test_ConsumeLineSkips(t *testing.T) {
{">PASSWORD"},
{">USERNAME"},
}
fas := fakeAuthenticatorStub{}
middleware := NewMiddleware(fas.fakeAuthenticator)
middleware := NewMiddleware()

for _, test := range tests {
consumed, err := middleware.ConsumeLine(test.line)
Expand All @@ -73,8 +53,7 @@ func Test_ConsumeLineTakes(t *testing.T) {
{">CLIENT:ENV,username=username"},
}

fas := fakeAuthenticatorStub{}
middleware := NewMiddleware(fas.fakeAuthenticator)
middleware := NewMiddleware()
mockConnection := &management.MockConnection{}
middleware.Start(mockConnection)

Expand All @@ -85,48 +64,33 @@ func Test_ConsumeLineTakes(t *testing.T) {
}
}

func Test_ConsumeLineAuthState(t *testing.T) {
func Test_ConsumeLineShouldNotTriggerClientState(t *testing.T) {
var tests = []struct {
line string
}{
{">CLIENT:ENV,password=12341234"},
{">CLIENT:ENV,username=username"},
{">CLIENT:REAUTH,0,0"},
{">CLIENT:CONNECT,0,0"},
}

for _, test := range tests {
fas := fakeAuthenticatorStub{}
middleware := NewMiddleware(fas.fakeAuthenticator)
mockConnection := &management.MockConnection{}
middleware.Start(mockConnection)

consumed, err := middleware.ConsumeLine(test.line)
assert.NoError(t, err, test.line)
assert.True(t, consumed, test.line)
}
}

func Test_ConsumeLineNotAuthState(t *testing.T) {
var tests = []struct {
line string
}{
{">CLIENT:ENV,password=12341234"},
{">CLIENT:ENV,username=username"},
}
var receivedEvent *server.ClientEvent
middleware := NewMiddleware(func(e server.ClientEvent) {
receivedEvent = &e
})

for _, test := range tests {
fas := fakeAuthenticatorStub{}
middleware := NewMiddleware(fas.fakeAuthenticator)
mockConnection := &management.MockConnection{}
middleware.Start(mockConnection)

consumed, err := middleware.ConsumeLine(test.line)
assert.NoError(t, err, test.line)
assert.True(t, consumed, test.line)
assert.False(t, fas.called)
assert.Nil(t, receivedEvent)
}
}

func Test_ConsumeLineAuthTrueChecker(t *testing.T) {
func Test_ConsumeLineShouldTriggerClientStateAfterReceivingEnvironment(t *testing.T) {
var tests = []struct {
line string
}{
Expand All @@ -135,9 +99,12 @@ func Test_ConsumeLineAuthTrueChecker(t *testing.T) {
{">CLIENT:ENV,username=username1"},
{">CLIENT:ENV,END"},
}
fas := fakeAuthenticatorStub{}
fas.authenticated = true
middleware := NewMiddleware(fas.fakeAuthenticator)

var receivedEvent server.ClientEvent
middleware := NewMiddleware(func(e server.ClientEvent) {
receivedEvent = e
})

mockConnection := &management.MockConnection{}
middleware.Start(mockConnection)

Expand All @@ -146,31 +113,49 @@ func Test_ConsumeLineAuthTrueChecker(t *testing.T) {
assert.NoError(t, err, test.line)
assert.True(t, consumed, test.line)
}
assert.True(t, fas.called)
assert.Equal(t, "username1", fas.username)
assert.Equal(t, "12341234", fas.password)
assert.Equal(t, "client-auth-nt 1 2", mockConnection.LastLine)
assert.Equal(
t,
server.ClientEvent{
EventType: server.Connect,
ClientID: 1,
ClientKey: 2,
Env: map[string]string{
"username": "username1",
"password": "12341234",
},
},
receivedEvent,
)
}

func Test_ConsumeLineAuthFalseChecker(t *testing.T) {
var tests = []struct {
line string
}{
{">CLIENT:CONNECT,3,4"},
{">CLIENT:ENV,username=bad"},
{">CLIENT:ENV,password=12341234"},
{">CLIENT:ENV,END"},
func Test_ConsumeLinesAcceptsClientIdsAntKeysWithSeveralDigits(t *testing.T) {
var tests = []string{
">CLIENT:CONNECT,115,23",
">CLIENT:ENV,END",
}
fas := fakeAuthenticatorStub{}
fas.authenticated = false
middleware := NewMiddleware(fas.fakeAuthenticator)

var receivedEvent server.ClientEvent
middleware := NewMiddleware(func(e server.ClientEvent) {
receivedEvent = e
})

mockConnection := &management.MockConnection{}
middleware.Start(mockConnection)

for _, test := range tests {
consumed, err := middleware.ConsumeLine(test.line)
assert.NoError(t, err, test.line)
assert.True(t, consumed, test.line)
for _, testLine := range tests {
consumed, err := middleware.ConsumeLine(testLine)
assert.NoError(t, err, testLine)
assert.Equal(t, true, consumed, testLine)
}
assert.Equal(t, "client-deny 3 4 wrong username or password", mockConnection.LastLine)

assert.Equal(
t,
server.ClientEvent{
EventType: server.Connect,
ClientID: 115,
ClientKey: 23,
Env: map[string]string{},
},
receivedEvent,
)
}
Loading

0 comments on commit 7e55222

Please sign in to comment.