diff --git a/src/tailscale/cli.ts b/src/tailscale/cli.ts index db6805d..f393888 100644 --- a/src/tailscale/cli.ts +++ b/src/tailscale/cli.ts @@ -154,11 +154,17 @@ export class Tailscale { Logger.info(`path: ${binPath}`, LOG_COMPONENT); this.notifyExit = () => { Logger.info('starting sudo tsrelay'); - const childProcess = cp.spawn(`/usr/bin/pkexec`, [ - '--disable-internal-agent', - binPath, - ...args, - ]); + let authCmd = `/usr/bin/pkexec`; + let authArgs = ['--disable-internal-agent', binPath, ...args]; + if ( + process.env['container'] === 'flatpak' && + process.env['FLATPAK_ID'] && + process.env['FLATPAK_ID'].startsWith('com.visualstudio.code') + ) { + authCmd = 'flatpak-spawn'; + authArgs = ['--host', 'pkexec', '--disable-internal-agent', binPath, ...args]; + } + const childProcess = cp.spawn(authCmd, authArgs); childProcess.on('exit', async (code) => { Logger.warn(`sudo child process exited with code ${code}`, LOG_COMPONENT); if (code === 0) { diff --git a/src/tailscale/error.ts b/src/tailscale/error.ts index bdaad0e..36963f5 100644 --- a/src/tailscale/error.ts +++ b/src/tailscale/error.ts @@ -43,6 +43,11 @@ export function errorForType(type: string): TailscaleError { }, ], }; + case 'FLATPAK_REQUIRES_RESTART': + return { + title: 'Restart Flatpak Container', + message: 'Please quit VSCode and restart the container to finish setting up Tailscale', + }; default: return { title: 'Unknown error', diff --git a/src/types.ts b/src/types.ts index 29b194a..419e99b 100644 --- a/src/types.ts +++ b/src/types.ts @@ -33,8 +33,14 @@ export interface WithErrors { Errors?: RelayError[]; } -interface RelayError { - Type: 'FUNNEL_OFF' | 'HTTPS_OFF' | 'OFFLINE' | 'REQUIRES_SUDO' | 'NOT_RUNNING'; +export interface RelayError { + Type: + | 'FUNNEL_OFF' + | 'HTTPS_OFF' + | 'OFFLINE' + | 'REQUIRES_SUDO' + | 'NOT_RUNNING' + | 'FLATPAK_REQUIRES_RESTART'; } interface PeerStatus { diff --git a/tsrelay/main.go b/tsrelay/main.go index a6488f7..b886b9c 100644 --- a/tsrelay/main.go +++ b/tsrelay/main.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "os" + "os/exec" "os/signal" "strconv" "strings" @@ -56,8 +57,13 @@ const ( // NotRunning indicates tailscaled is // not running NotRunning = "NOT_RUNNING" + // FlatpakRequiresRestart indicates that the flatpak + // container needs to be fully restarted + FlatpakRequiresRestart = "FLATPAK_REQUIRES_RESTART" ) +var requiresRestart bool + func main() { must(run()) } @@ -78,12 +84,50 @@ func run() error { Logger: log.New(logOut, "", 0), } + flatpakID := os.Getenv("FLATPAK_ID") + isFlatpak := os.Getenv("container") == "flatpak" && strings.HasPrefix(flatpakID, "com.visualstudio.code") + if isFlatpak { + lggr.Println("running inside flatpak") + var err error + requiresRestart, err = ensureTailscaledAccessible(lggr, flatpakID) + if err != nil { + return err + } + lggr.Printf("requires restart: %v", requiresRestart) + } + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) defer cancel() return runHTTPServer(ctx, lggr, *port, *nonce) } +func ensureTailscaledAccessible(lggr *logger, flatpakID string) (bool, error) { + _, err := os.Stat("/run/tailscale") + if err == nil { + lggr.Println("tailscaled is accessible") + return false, nil + } + if !errors.Is(err, os.ErrNotExist) { + return false, fmt.Errorf("error checking /run/tailscale: %w", err) + } + lggr.Println("running flatpak override") + cmd := exec.Command( + "flatpak-spawn", + "--host", + "flatpak", + "override", + "--user", + flatpakID, + "--filesystem=/run/tailscale", + ) + output, err := cmd.Output() + if err != nil { + return false, fmt.Errorf("error running flatpak override: %s - %w", output, err) + } + return true, nil +} + type serverDetails struct { Address string `json:"address,omitempty"` Nonce string `json:"nonce,omitempty"` @@ -252,6 +296,12 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.Write([]byte(`{}`)) case http.MethodGet: + if requiresRestart { + json.NewEncoder(w).Encode(RelayError{ + Errors: []Error{{Type: FlatpakRequiresRestart}}, + }) + return + } var wg sync.WaitGroup wg.Add(1) portMap := map[uint16]string{} @@ -414,8 +464,8 @@ func (h *httpHandler) getConfigs(ctx context.Context) (*ipnstate.Status, *ipn.Se var ( st *ipnstate.Status sc *ipn.ServeConfig - g errgroup.Group ) + g, ctx := errgroup.WithContext(ctx) g.Go(func() error { var err error sc, err = h.lc.GetServeConfig(ctx)