Skip to content

Commit

Permalink
Notify connecting/disconnecting state from vpn connection on Win
Browse files Browse the repository at this point in the history
  • Loading branch information
simonhong committed Sep 16, 2021
1 parent 5d29af1 commit d035fb6
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 17 deletions.
59 changes: 50 additions & 9 deletions components/brave_vpn/brave_vpn_os_connection_api_win.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
// (brian@clifton.me)'s work (https://github.com/bsclifton/winvpntool).

using brave_vpn::internal::CheckConnectionResult;
using brave_vpn::internal::CloseEventHandleForConnecting;
using brave_vpn::internal::CreateEntry;
using brave_vpn::internal::GetEventHandleForConnecting;
using brave_vpn::internal::GetPhonebookPath;
using brave_vpn::internal::PrintRasError;
using brave_vpn::internal::RemoveEntry;
Expand Down Expand Up @@ -47,7 +49,9 @@ BraveVPNOSConnectionAPIWin::BraveVPNOSConnectionAPIWin() {
}

BraveVPNOSConnectionAPIWin::~BraveVPNOSConnectionAPIWin() {
CloseHandle(event_handle_);
CloseHandle(event_handle_for_connected_);
CloseHandle(event_handle_for_disconnected_);
CloseEventHandleForConnecting();
}

void BraveVPNOSConnectionAPIWin::CreateVPNConnection(
Expand Down Expand Up @@ -75,6 +79,10 @@ void BraveVPNOSConnectionAPIWin::Connect(const std::string& name) {
}

void BraveVPNOSConnectionAPIWin::Disconnect(const std::string& name) {
// Fire pseudo disconnecting noti because windows doesn't have it.
for (Observer& obs : observers_)
obs.OnIsDisconnecting(name);

// Connection state update from this call will be done by monitoring.
base::ThreadPool::PostTask(
FROM_HERE, {base::MayBlock()},
Expand All @@ -100,7 +108,18 @@ void BraveVPNOSConnectionAPIWin::CheckConnection(const std::string& name) {
void BraveVPNOSConnectionAPIWin::OnObjectSignaled(HANDLE object) {
DCHECK(!target_vpn_entry_name().empty());

CheckConnection(target_vpn_entry_name());
CheckConnectionResult result = CheckConnectionResult::UNKNOWN;
if (object == GetEventHandleForConnecting()) {
result = CheckConnectionResult::CONNECTING;
} else if (object == event_handle_for_connected_) {
result = CheckConnectionResult::CONNECTED;
} else if (object == event_handle_for_disconnected_) {
result = CheckConnectionResult::DISCONNECTED;
} else {
NOTREACHED();
}

OnCheckConnection(target_vpn_entry_name(), result);
}

void BraveVPNOSConnectionAPIWin::OnCheckConnection(
Expand All @@ -109,9 +128,20 @@ void BraveVPNOSConnectionAPIWin::OnCheckConnection(
if (result == CheckConnectionResult::UNKNOWN)
return;

const bool connected = result == CheckConnectionResult::CONNECTED;
for (Observer& obs : observers_) {
connected ? obs.OnConnected(name) : obs.OnDisconnected(name);
switch (result) {
case CheckConnectionResult::CONNECTED:
obs.OnConnected(name);
break;
case CheckConnectionResult::CONNECTING:
obs.OnIsConnecting(name);
break;
case CheckConnectionResult::DISCONNECTED:
obs.OnDisconnected(name);
break;
default:
break;
}
}
}

Expand All @@ -134,13 +164,24 @@ void BraveVPNOSConnectionAPIWin::OnRemoved(const std::string& name,
}

void BraveVPNOSConnectionAPIWin::StartVPNConnectionChangeMonitoring() {
DCHECK(!event_handle_);
event_handle_ = CreateEvent(nullptr, false, false, nullptr);
DCHECK(!event_handle_for_connected_ && !event_handle_for_disconnected_);

event_handle_for_connected_ = CreateEvent(nullptr, false, false, nullptr);
event_handle_for_disconnected_ = CreateEvent(nullptr, false, false, nullptr);

// We don't need to check current connection state again if monitor each event
// separately.
RasConnectionNotificationW(static_cast<HRASCONN>(INVALID_HANDLE_VALUE),
event_handle_for_connected_, RASCN_Connection);
RasConnectionNotificationW(static_cast<HRASCONN>(INVALID_HANDLE_VALUE),
event_handle_,
(RASCN_Connection | RASCN_Disconnection));
watcher_.StartWatchingMultipleTimes(event_handle_, this);
event_handle_for_disconnected_,
RASCN_Disconnection);
connected_event_watcher_.StartWatchingMultipleTimes(
event_handle_for_connected_, this);
disconnected_event_watcher_.StartWatchingMultipleTimes(
event_handle_for_disconnected_, this);
connecting_event_watcher_.StartWatchingMultipleTimes(
GetEventHandleForConnecting(), this);
}

} // namespace brave_vpn
7 changes: 5 additions & 2 deletions components/brave_vpn/brave_vpn_os_connection_api_win.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,11 @@ class BraveVPNOSConnectionAPIWin : public BraveVPNOSConnectionAPI,

void StartVPNConnectionChangeMonitoring();

HANDLE event_handle_ = nullptr;
base::win::ObjectWatcher watcher_;
HANDLE event_handle_for_connected_ = nullptr;
HANDLE event_handle_for_disconnected_ = nullptr;
base::win::ObjectWatcher connected_event_watcher_;
base::win::ObjectWatcher connecting_event_watcher_;
base::win::ObjectWatcher disconnected_event_watcher_;
base::WeakPtrFactory<BraveVPNOSConnectionAPIWin> weak_factory_{this};
};

Expand Down
73 changes: 68 additions & 5 deletions components/brave_vpn/utils_win.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@ namespace brave_vpn {

namespace {

HANDLE g_event_handle = nullptr;

void WINAPI RasDialFunc(UINT, RASCONNSTATE rasconnstate, DWORD error) {
if (error) {
internal::PrintRasError(error);
return;
}

// Only interested in connecting event.
switch (rasconnstate) {
case RASCS_ConnectDevice:
SetEvent(g_event_handle);
break;
default:
// Ignore all other states.
break;
}
}

// https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-formatmessage
void PrintSystemError(DWORD error) {
constexpr DWORD kBufSize = 512;
Expand Down Expand Up @@ -57,6 +76,19 @@ DWORD SetCredentials(LPCTSTR entry_name, LPCTSTR username, LPCTSTR password) {

namespace internal {

HANDLE GetEventHandleForConnecting() {
if (!g_event_handle)
g_event_handle = CreateEvent(nullptr, false, false, nullptr);
return g_event_handle;
}

void CloseEventHandleForConnecting() {
if (g_event_handle) {
CloseHandle(g_event_handle);
g_event_handle = nullptr;
}
}

// https://docs.microsoft.com/en-us/windows/win32/api/ras/nf-ras-rasgeterrorstringa
void PrintRasError(DWORD error) {
constexpr DWORD kBufSize = 512;
Expand Down Expand Up @@ -197,8 +229,8 @@ bool ConnectEntry(const std::wstring& entry_name) {

DVLOG(2) << "Connecting to " << entry_name;
HRASCONN h_ras_conn = NULL;
dw_ret = RasDial(NULL, DEFAULT_PHONE_BOOK, lp_ras_dial_params, NULL, NULL,
&h_ras_conn);
dw_ret = RasDial(NULL, DEFAULT_PHONE_BOOK, lp_ras_dial_params, 0,
(LPVOID)(&RasDialFunc), &h_ras_conn);
if (dw_ret != ERROR_SUCCESS) {
HeapFree(GetProcessHeap(), 0, (LPVOID)lp_ras_dial_params);
PrintRasError(dw_ret);
Expand Down Expand Up @@ -324,6 +356,37 @@ bool CreateEntry(const std::wstring& entry_name,
return true;
}

CheckConnectionResult GetConnectionState(HRASCONN h_ras_conn) {
DWORD dw_ret = 0;

RASCONNSTATUS ras_conn_status;
ZeroMemory(&ras_conn_status, sizeof(RASCONNSTATUS));
ras_conn_status.dwSize = sizeof(RASCONNSTATUS);

// Checking connection status using RasGetConnectStatus
dw_ret = RasGetConnectStatus(h_ras_conn, &ras_conn_status);
if (ERROR_SUCCESS != dw_ret) {
LOG(ERROR) << "RasGetConnectStatus failed: Error = " << dw_ret;
return CheckConnectionResult::UNKNOWN;
}

switch (ras_conn_status.rasconnstate) {
case RASCS_ConnectDevice:
VLOG(2) << "Connecting device...";
return CheckConnectionResult::CONNECTING;
case RASCS_Connected:
VLOG(2) << "Connection completed";
return CheckConnectionResult::CONNECTED;
case RASCS_Disconnected:
VLOG(2) << "Disconnected";
return CheckConnectionResult::DISCONNECTED;
default:
break;
}

return CheckConnectionResult::DISCONNECTED;
}

CheckConnectionResult CheckConnection(const std::wstring& entry_name) {
if (entry_name.empty())
return CheckConnectionResult::UNKNOWN;
Expand All @@ -339,7 +402,7 @@ CheckConnectionResult CheckConnection(const std::wstring& entry_name) {

// If got success here, it means there is no connected vpn entry.
if (dw_ret == ERROR_SUCCESS) {
return CheckConnectionResult::NOT_CONNECTED;
return CheckConnectionResult::DISCONNECTED;
}

// Abnormal situation.
Expand Down Expand Up @@ -367,10 +430,10 @@ CheckConnectionResult CheckConnection(const std::wstring& entry_name) {
}

// If successful, find connection with |entry_name|.
CheckConnectionResult result = CheckConnectionResult::NOT_CONNECTED;
CheckConnectionResult result = CheckConnectionResult::DISCONNECTED;
for (DWORD i = 0; i < dw_connections; i++) {
if (entry_name.compare(lp_ras_conn[i].szEntryName) == 0) {
result = CheckConnectionResult::CONNECTED;
result = GetConnectionState(lp_ras_conn[i].hrasconn);
break;
}
}
Expand Down
6 changes: 5 additions & 1 deletion components/brave_vpn/utils_win.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ namespace internal {

enum class CheckConnectionResult {
CONNECTED,
NOT_CONNECTED,
CONNECTING,
DISCONNECTED,
UNKNOWN,
};

Expand All @@ -29,6 +30,9 @@ bool CreateEntry(const std::wstring& entry_name,
bool RemoveEntry(const std::wstring& entry_name);
bool DisconnectEntry(const std::wstring& entry_name);
bool ConnectEntry(const std::wstring& entry_name);
// Don't cache returned HANDLE. It could be invalidated.
HANDLE GetEventHandleForConnecting();
void CloseEventHandleForConnecting();

CheckConnectionResult CheckConnection(const std::wstring& entry_name);

Expand Down

0 comments on commit d035fb6

Please sign in to comment.