Skip to content

Commit

Permalink
Merge pull request #65 from sonroyaalmerol/rework-stream-proxy
Browse files Browse the repository at this point in the history
Rework load balancing connection algorithm
  • Loading branch information
sonroyaalmerol committed Jul 18, 2024
2 parents e9d611a + d20c0c8 commit 4f970f8
Showing 1 changed file with 89 additions and 150 deletions.
239 changes: 89 additions & 150 deletions stream_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,109 +3,100 @@ package main
import (
"context"
"fmt"
"io"
"log"
"m3u-stream-merger/database"
"m3u-stream-merger/utils"
"net/http"
"os"
"strconv"
"strings"
"time"
)

func loadBalancer(stream database.StreamInfo) (resp *http.Response, selectedUrl *database.StreamURL, err error) {
func loadBalancer(stream database.StreamInfo) (*http.Response, *database.StreamURL, error) {
loadBalancingMode := os.Getenv("LOAD_BALANCING_MODE")
if loadBalancingMode == "" {
loadBalancingMode = "brute-force"
}

switch loadBalancingMode {
case "round-robin":
var lastIndex int // Track the last index used
var lastIndex int

// Round-robin mode
for i := 0; i < len(stream.URLs); i++ {
index := (lastIndex + i) % len(stream.URLs) // Calculate the next index
url := stream.URLs[index]

if checkConcurrency(url.M3UIndex) {
maxCon := os.Getenv(fmt.Sprintf("M3U_MAX_CONCURRENCY_%d", url.M3UIndex))
if strings.TrimSpace(maxCon) == "" {
maxCon = "1"
}
log.Printf("Concurrency limit reached for M3U_%d (max: %s): %s", url.M3UIndex, maxCon, url.Content)
continue // Skip this stream if concurrency limit reached
}

resp, err = utils.CustomHttpRequest("GET", url.Content)
if err == nil {
selectedUrl = &url
break
}
for i := 0; i < len(stream.URLs); i++ {
index := i
if loadBalancingMode == "round-robin" {
index = (lastIndex + i) % len(stream.URLs)
}

// Log the error
log.Printf("Error fetching stream (concurrency round robin mode): %s\n", err.Error())
url := stream.URLs[index]

lastIndex = (lastIndex + 1) % len(stream.URLs) // Update the last index used
if checkConcurrency(url.M3UIndex) {
log.Printf("Concurrency limit reached for M3U_%d: %s", url.M3UIndex, url.Content)
continue
}
case "brute-force":
// Brute force mode
for _, url := range stream.URLs {
if checkConcurrency(url.M3UIndex) {
maxCon := os.Getenv(fmt.Sprintf("M3U_MAX_CONCURRENCY_%d", url.M3UIndex))
if strings.TrimSpace(maxCon) == "" {
maxCon = "1"
}
log.Printf("Concurrency limit reached for M3U_%d (max: %s): %s", url.M3UIndex, maxCon, url.Content)
continue // Skip this stream if concurrency limit reached
}

resp, err = utils.CustomHttpRequest("GET", url.Content)
if err == nil {
selectedUrl = &url
break
}
resp, err := utils.CustomHttpRequest("GET", url.Content)
if err == nil {
return resp, &url, nil
}
log.Printf("Error fetching stream: %s\n", err.Error())

// Log the error
log.Printf("Error fetching stream (concurrency brute force mode): %s\n", err.Error())
if loadBalancingMode == "round-robin" {
lastIndex = (lastIndex + 1) % len(stream.URLs)
}
default:
log.Printf("Invalid LOAD_BALANCING_MODE. Skipping concurrency mode...")
}

if selectedUrl == nil {
log.Printf("All concurrency limits have been reached. Falling back to connection checking mode...\n")
// Connection check mode
for _, url := range stream.URLs {
resp, err = utils.CustomHttpRequest("GET", url.Content)
if err == nil {
selectedUrl = &url
break
} else {
// Log the error
log.Printf("Error fetching stream (connection check mode): %s\n", err.Error())
}
log.Printf("All concurrency limits have been reached. Falling back to connection checking mode...\n")
for _, url := range stream.URLs {
resp, err := utils.CustomHttpRequest("GET", url.Content)
if err == nil {
return resp, &url, nil
}
log.Printf("Error fetching stream: %s\n", err.Error())
}

if resp == nil {
// Log the error
return nil, nil, fmt.Errorf("Error fetching stream. Exhausted all streams.")
}
return nil, nil, fmt.Errorf("Error fetching stream. Exhausted all streams.")
}

// proxyStream copies the upstream response body to the client and reports the
// outcome on statusChan: 1 when the upstream side ended (EOF or read error),
// 0 when writing to the client failed.
//
// While running it registers one active connection against the URL's M3U
// source via updateConcurrency, released on return.
func proxyStream(selectedUrl *database.StreamURL, resp *http.Response, r *http.Request, w http.ResponseWriter, statusChan chan int) {
	updateConcurrency(selectedUrl.M3UIndex, true)
	defer updateConcurrency(selectedUrl.M3UIndex, false)

	// BUFFER_MB sizes the copy buffer in MiB; unset, non-numeric, zero or
	// negative values fall back to a 1 KiB buffer.
	bufferMbInt, _ := strconv.Atoi(os.Getenv("BUFFER_MB"))
	if bufferMbInt < 0 {
		log.Printf("Invalid BUFFER_MB value: negative integer is not allowed\n")
		bufferMbInt = 0
	}
	bufSize := 1024
	if bufferMbInt > 0 {
		bufSize = bufferMbInt * 1024 * 1024
	}
	buffer := make([]byte, bufSize)

	for {
		n, err := resp.Body.Read(buffer)
		// Per the io.Reader contract, process the n > 0 bytes BEFORE
		// inspecting err: a reader may return the final chunk of data
		// together with io.EOF. The previous version checked err first
		// and silently dropped that last chunk.
		if n > 0 {
			if _, werr := w.Write(buffer[:n]); werr != nil {
				log.Printf("Error writing to response: %s\n", werr.Error())
				statusChan <- 0
				return
			}
		}
		if err != nil {
			if err == io.EOF {
				log.Printf("Stream ended (EOF reached): %s\n", r.RemoteAddr)
			} else {
				log.Printf("Error reading stream: %s\n", err.Error())
			}
			statusChan <- 1
			return
		}
	}
}

func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance) {
ctx, cancel := context.WithCancel(r.Context())
defer cancel()

// Log the incoming request
log.Printf("Received request from %s for URL: %s\n", r.RemoteAddr, r.URL.Path)

// Extract the m3u ID from the URL path
m3uID := strings.Split(strings.TrimPrefix(r.URL.Path, "/stream/"), ".")[0]
if m3uID == "" {
http.NotFound(w, r)
Expand All @@ -124,25 +115,13 @@ func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance
return
}

var resp *http.Response
defer func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
}()

// Iterate through the streams and select one based on concurrency and availability
var selectedUrl *database.StreamURL

resp, selectedUrl, err = loadBalancer(stream)
resp, selectedUrl, err := loadBalancer(stream)
if err != nil {
http.Error(w, "Error fetching stream. Exhausted all streams.", http.StatusInternalServerError)
return
}
log.Printf("Proxying %s to %s\n", r.RemoteAddr, selectedUrl.Content)

// Log the successful response
log.Printf("Sent stream to %s\n", r.RemoteAddr)
log.Printf("Proxying %s to %s\n", r.RemoteAddr, selectedUrl.Content)

for k, v := range resp.Header {
for _, val := range v {
Expand All @@ -153,87 +132,47 @@ func streamHandler(w http.ResponseWriter, r *http.Request, db *database.Instance
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Access-Control-Allow-Origin", "*")

// Set initial timer duration
timerDuration := 15 * time.Second
timer := time.NewTimer(timerDuration)

// Function to reset the timer
resetTimer := func() {
timer.Reset(timerDuration)
}

go func() {
updateConcurrency(selectedUrl.M3UIndex, true)

bufferMbInt := 0
bufferMb := os.Getenv("BUFFER_MB")
if bufferMb != "" {
bufferMbInt, err = strconv.Atoi(bufferMb)
if err != nil {
log.Printf("Invalid BUFFER_MB value: %s\n", err.Error())
bufferMbInt = 0
}

if bufferMbInt < 0 {
log.Printf("Invalid BUFFER_MB value: negative integer is not allowed\n")
}
}

buffer := make([]byte, 512)
if bufferMbInt > 0 {
log.Printf("Buffer is set to %dmb.\n", bufferMbInt)
buffer = make([]byte, 1024*bufferMbInt)
}
for {
select {
case <-timer.C:
log.Printf("Connection timed out: %s\n", r.RemoteAddr)

log.Printf("Closing (%s) connection.\n", r.RemoteAddr)
cancel()
return
default:
}
n, err := resp.Body.Read(buffer)
if err != nil {
log.Printf("Error reading stream: %s\n", err.Error())
break
}
if n > 0 {
resetTimer()
_, err := w.Write(buffer[:n])
for {
select {
case <-ctx.Done():
log.Printf("Client disconnected: %s\n", r.RemoteAddr)
resp.Body.Close()
break
default:
exitStatus := make(chan int)
go proxyStream(selectedUrl, resp, r, w, exitStatus)
streamExitCode := <-exitStatus

if streamExitCode == 1 {
// Retry on server-side connection errors
log.Printf("Server connection failed: %s\n", selectedUrl.Content)
log.Printf("Retrying other servers...\n")
resp.Body.Close()
resp, selectedUrl, err = loadBalancer(stream)
if err != nil {
log.Printf("Error writing to response: %s\n", err.Error())
break
http.Error(w, "Error fetching stream. Exhausted all streams.", http.StatusInternalServerError)
return
}
log.Printf("Reconnected to %s\n", selectedUrl.Content)
} else {
// Consider client-side connection errors as complete closure
log.Printf("Client has closed the stream: %s\n", r.RemoteAddr)
cancel()
}
}

log.Printf("Closing (%s) connection.\n", r.RemoteAddr)
cancel()
}()

// Wait for the request context to be canceled or the stream to finish
<-ctx.Done()
log.Printf("Client (%s) disconnected.\n", r.RemoteAddr)
updateConcurrency(selectedUrl.M3UIndex, false)
}
}

func checkConcurrency(m3uIndex int) bool {
maxConcurrency := 1
var err error
rawMaxConcurrency, maxConcurrencyExists := os.LookupEnv(fmt.Sprintf("M3U_MAX_CONCURRENCY_%d", m3uIndex))
if maxConcurrencyExists {
maxConcurrency, err = strconv.Atoi(rawMaxConcurrency)
if err != nil {
maxConcurrency = 1
}
maxConcurrency, err := strconv.Atoi(os.Getenv(fmt.Sprintf("M3U_MAX_CONCURRENCY_%d", m3uIndex)))
if err != nil {
maxConcurrency = 1
}

count, err := database.GetConcurrency(m3uIndex)
if err != nil {
log.Printf("Error checking concurrency: %s\n", err.Error())
return false // Error occurred, treat as concurrency not reached
return false
}

log.Printf("Current number of connections for M3U_%d: %d", m3uIndex, count)
Expand Down

0 comments on commit 4f970f8

Please sign in to comment.