diff --git a/components/brave_vpn/brave_vpn_os_connection_api_win.cc b/components/brave_vpn/brave_vpn_os_connection_api_win.cc index d587858c98b5..7a3c969466c8 100644 --- a/components/brave_vpn/brave_vpn_os_connection_api_win.cc +++ b/components/brave_vpn/brave_vpn_os_connection_api_win.cc @@ -17,7 +17,9 @@ // (brian@clifton.me)'s work (https://github.com/bsclifton/winvpntool). using brave_vpn::internal::CheckConnectionResult; +using brave_vpn::internal::CloseEventHandleForConnecting; using brave_vpn::internal::CreateEntry; +using brave_vpn::internal::GetEventHandleForConnecting; using brave_vpn::internal::GetPhonebookPath; using brave_vpn::internal::PrintRasError; using brave_vpn::internal::RemoveEntry; @@ -47,7 +49,9 @@ BraveVPNOSConnectionAPIWin::BraveVPNOSConnectionAPIWin() { } BraveVPNOSConnectionAPIWin::~BraveVPNOSConnectionAPIWin() { - CloseHandle(event_handle_); + CloseHandle(event_handle_for_connected_); + CloseHandle(event_handle_for_disconnected_); + CloseEventHandleForConnecting(); } void BraveVPNOSConnectionAPIWin::CreateVPNConnection( @@ -75,6 +79,10 @@ void BraveVPNOSConnectionAPIWin::Connect(const std::string& name) { } void BraveVPNOSConnectionAPIWin::Disconnect(const std::string& name) { + // Fire pseudo disconnecting noti because windows doesn't have it. + for (Observer& obs : observers_) + obs.OnIsDisconnecting(name); + // Connection state update from this call will be done by monitoring. base::ThreadPool::PostTask( FROM_HERE, {base::MayBlock()}, @@ -100,7 +108,18 @@ void BraveVPNOSConnectionAPIWin::CheckConnection(const std::string& name) { void BraveVPNOSConnectionAPIWin::OnObjectSignaled(HANDLE object) { DCHECK(!target_vpn_entry_name().empty()); - CheckConnection(target_vpn_entry_name()); + CheckConnectionResult result = CheckConnectionResult::UNKNOWN; + if (object == GetEventHandleForConnecting()) { + result = CheckConnectionResult::CONNECTING; + } else if (object == event_handle_for_connected_) { + result = CheckConnectionResult::CONNECTED; + } else if (object == event_handle_for_disconnected_) { + result = CheckConnectionResult::DISCONNECTED; + } else { + NOTREACHED(); + } + + OnCheckConnection(target_vpn_entry_name(), result); } void BraveVPNOSConnectionAPIWin::OnCheckConnection( @@ -109,9 +128,20 @@ void BraveVPNOSConnectionAPIWin::OnCheckConnection( if (result == CheckConnectionResult::UNKNOWN) return; - const bool connected = result == CheckConnectionResult::CONNECTED; for (Observer& obs : observers_) { - connected ? obs.OnConnected(name) : obs.OnDisconnected(name); + switch (result) { + case CheckConnectionResult::CONNECTED: + obs.OnConnected(name); + break; + case CheckConnectionResult::CONNECTING: + obs.OnIsConnecting(name); + break; + case CheckConnectionResult::DISCONNECTED: + obs.OnDisconnected(name); + break; + default: + break; + } } } @@ -134,13 +164,24 @@ void BraveVPNOSConnectionAPIWin::OnRemoved(const std::string& name, } void BraveVPNOSConnectionAPIWin::StartVPNConnectionChangeMonitoring() { - DCHECK(!event_handle_); - event_handle_ = CreateEvent(nullptr, false, false, nullptr); + DCHECK(!event_handle_for_connected_ && !event_handle_for_disconnected_); + + event_handle_for_connected_ = CreateEvent(nullptr, false, false, nullptr); + event_handle_for_disconnected_ = CreateEvent(nullptr, false, false, nullptr); + // We don't need to check current connection state again if monitor each event + // separately. + RasConnectionNotificationW(static_cast(INVALID_HANDLE_VALUE), + event_handle_for_connected_, RASCN_Connection); RasConnectionNotificationW(static_cast(INVALID_HANDLE_VALUE), - event_handle_, - (RASCN_Connection | RASCN_Disconnection)); - watcher_.StartWatchingMultipleTimes(event_handle_, this); + event_handle_for_disconnected_, + RASCN_Disconnection); + connected_event_watcher_.StartWatchingMultipleTimes( + event_handle_for_connected_, this); + disconnected_event_watcher_.StartWatchingMultipleTimes( + event_handle_for_disconnected_, this); + connecting_event_watcher_.StartWatchingMultipleTimes( + GetEventHandleForConnecting(), this); } } // namespace brave_vpn diff --git a/components/brave_vpn/brave_vpn_os_connection_api_win.h b/components/brave_vpn/brave_vpn_os_connection_api_win.h index a72dae851506..5dcc615603a3 100644 --- a/components/brave_vpn/brave_vpn_os_connection_api_win.h +++ b/components/brave_vpn/brave_vpn_os_connection_api_win.h @@ -51,8 +51,11 @@ class BraveVPNOSConnectionAPIWin : public BraveVPNOSConnectionAPI, void StartVPNConnectionChangeMonitoring(); - HANDLE event_handle_ = nullptr; - base::win::ObjectWatcher watcher_; + HANDLE event_handle_for_connected_ = nullptr; + HANDLE event_handle_for_disconnected_ = nullptr; + base::win::ObjectWatcher connected_event_watcher_; + base::win::ObjectWatcher connecting_event_watcher_; + base::win::ObjectWatcher disconnected_event_watcher_; base::WeakPtrFactory weak_factory_{this}; }; diff --git a/components/brave_vpn/utils_win.cc b/components/brave_vpn/utils_win.cc index 29cc018fbbfa..f1a77aa6b296 100644 --- a/components/brave_vpn/utils_win.cc +++ b/components/brave_vpn/utils_win.cc @@ -18,6 +18,25 @@ namespace brave_vpn { namespace { +HANDLE g_event_handle = nullptr; + +void WINAPI RasDialFunc(UINT, RASCONNSTATE rasconnstate, DWORD error) { + if (error) { + internal::PrintRasError(error); + return; + } + + // Only interested in connecting event. + switch (rasconnstate) { + case RASCS_ConnectDevice: + SetEvent(g_event_handle); + break; + default: + // Ignore all other states. + break; + } +} + // https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-formatmessage void PrintSystemError(DWORD error) { constexpr DWORD kBufSize = 512; @@ -57,6 +76,19 @@ DWORD SetCredentials(LPCTSTR entry_name, LPCTSTR username, LPCTSTR password) { namespace internal { +HANDLE GetEventHandleForConnecting() { + if (!g_event_handle) + g_event_handle = CreateEvent(nullptr, false, false, nullptr); + return g_event_handle; +} + +void CloseEventHandleForConnecting() { + if (g_event_handle) { + CloseHandle(g_event_handle); + g_event_handle = nullptr; + } +} + // https://docs.microsoft.com/en-us/windows/win32/api/ras/nf-ras-rasgeterrorstringa void PrintRasError(DWORD error) { constexpr DWORD kBufSize = 512; @@ -197,8 +229,8 @@ bool ConnectEntry(const std::wstring& entry_name) { DVLOG(2) << "Connecting to " << entry_name; HRASCONN h_ras_conn = NULL; - dw_ret = RasDial(NULL, DEFAULT_PHONE_BOOK, lp_ras_dial_params, NULL, NULL, - &h_ras_conn); + dw_ret = RasDial(NULL, DEFAULT_PHONE_BOOK, lp_ras_dial_params, 0, + (LPVOID)(&RasDialFunc), &h_ras_conn); if (dw_ret != ERROR_SUCCESS) { HeapFree(GetProcessHeap(), 0, (LPVOID)lp_ras_dial_params); PrintRasError(dw_ret); @@ -324,6 +356,37 @@ bool CreateEntry(const std::wstring& entry_name, return true; } +CheckConnectionResult GetConnectionState(HRASCONN h_ras_conn) { + DWORD dw_ret = 0; + + RASCONNSTATUS ras_conn_status; + ZeroMemory(&ras_conn_status, sizeof(RASCONNSTATUS)); + ras_conn_status.dwSize = sizeof(RASCONNSTATUS); + + // Checking connection status using RasGetConnectStatus + dw_ret = RasGetConnectStatus(h_ras_conn, &ras_conn_status); + if (ERROR_SUCCESS != dw_ret) { + LOG(ERROR) << "RasGetConnectStatus failed: Error = " << dw_ret; + return CheckConnectionResult::UNKNOWN; + } + + switch (ras_conn_status.rasconnstate) { + case RASCS_ConnectDevice: + VLOG(2) << "Connecting device..."; + return CheckConnectionResult::CONNECTING; + case RASCS_Connected: + VLOG(2) << "Connection completed"; + return CheckConnectionResult::CONNECTED; + case RASCS_Disconnected: + VLOG(2) << "Disconnected"; + return CheckConnectionResult::DISCONNECTED; + default: + break; + } + + return CheckConnectionResult::DISCONNECTED; +} + CheckConnectionResult CheckConnection(const std::wstring& entry_name) { if (entry_name.empty()) return CheckConnectionResult::UNKNOWN; @@ -339,7 +402,7 @@ CheckConnectionResult CheckConnection(const std::wstring& entry_name) { // If got success here, it means there is no connected vpn entry. if (dw_ret == ERROR_SUCCESS) { - return CheckConnectionResult::NOT_CONNECTED; + return CheckConnectionResult::DISCONNECTED; } // Abnormal situation. @@ -367,10 +430,10 @@ CheckConnectionResult CheckConnection(const std::wstring& entry_name) { } // If successful, find connection with |entry_name|. - CheckConnectionResult result = CheckConnectionResult::NOT_CONNECTED; + CheckConnectionResult result = CheckConnectionResult::DISCONNECTED; for (DWORD i = 0; i < dw_connections; i++) { if (entry_name.compare(lp_ras_conn[i].szEntryName) == 0) { - result = CheckConnectionResult::CONNECTED; + result = GetConnectionState(lp_ras_conn[i].hrasconn); break; } } diff --git a/components/brave_vpn/utils_win.h b/components/brave_vpn/utils_win.h index 8559e15ea818..e8bfb98f5f4f 100644 --- a/components/brave_vpn/utils_win.h +++ b/components/brave_vpn/utils_win.h @@ -15,7 +15,8 @@ namespace internal { enum class CheckConnectionResult { CONNECTED, - NOT_CONNECTED, + CONNECTING, + DISCONNECTED, UNKNOWN, }; @@ -29,6 +30,9 @@ bool CreateEntry(const std::wstring& entry_name, bool RemoveEntry(const std::wstring& entry_name); bool DisconnectEntry(const std::wstring& entry_name); bool ConnectEntry(const std::wstring& entry_name); +// Don't cache returned HANDLE. It could be invalidated. +HANDLE GetEventHandleForConnecting(); +void CloseEventHandleForConnecting(); CheckConnectionResult CheckConnection(const std::wstring& entry_name);