depot/go/minotarproxy/minotarproxy.go
2022-01-22 15:28:31 +00:00

288 lines
7.8 KiB
Go

package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"sync"
"time"
"github.com/golang/glog"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/crypto/acme/autocert"
)
// Server-side flag: where the proxy itself listens.
var serverBind = flag.String("server_bind", "185.230.222.225:13231", "IP address/port to bind proxy server to")

// Outbound flags: which local addresses upstream requests are sent from
// (default built by generateOutbound), and how incoming path prefixes map
// onto upstream base URLs.
var (
	outboundBind = flag.String("outbound_bind", generateOutbound(), "IP addresses to bind outbound requests to")
	outboundMap  = flag.String("outbound_map", "/users/=https://api.mojang.com/users/,/session/=https://sessionserver.mojang.com/session/", "Comma-separated mapping of [path prefix]=[remote base URL]")
)

// Autocert flags: optional ACME-managed TLS for server_bind. TLS is only
// enabled when autocert_domain is non-empty.
var (
	autocertDomain       = flag.String("autocert_domain", "", "If non-empty, enables autocert for the provided domain.")
	autocertInsecureBind = flag.String("autocert_insecure_bind", "185.230.222.225:80", "Port 80 IP to bind for HTTP-01 ACME challenge.")
	autocertCacheDir     = flag.String("autocert_cache_dir", "/home/minotarproxy/autocertcache", "Autocert cache directory.")
)
// generateOutbound builds the default value for -outbound_bind: every
// address from 185.230.222.225 through 185.230.222.253 inclusive, in
// ascending order, joined with commas.
func generateOutbound() string {
	const firstHost, lastHost = 225, 253
	addrs := make([]string, 0, lastHost-firstHost+1)
	for host := firstHost; host <= lastHost; host++ {
		addrs = append(addrs, fmt.Sprintf("185.230.222.%d", host))
	}
	return strings.Join(addrs, ",")
}
// roundRobinMetrics groups the Prometheus collectors that
// roundRobinRoundtripper updates on every proxied request.
type roundRobinMetrics struct {
	// mRequests counts requests sent upstream, labelled by handler.
	mRequests *prometheus.CounterVec

	// mResponses counts upstream results, labelled by handler, sourceIP
	// and responseCode.
	mResponses *prometheus.CounterVec

	// mInflight tracks requests currently in flight, labelled by handler.
	mInflight *prometheus.GaugeVec

	// mLatency observes upstream round-trip time in seconds, labelled by
	// handler.
	mLatency *prometheus.HistogramVec
}
// roundRobinRoundtripper is an http.RoundTripper that rotates through a
// pool of transports, one per outbound source IP, recording metrics for
// each request it forwards.
type roundRobinRoundtripper struct {
	// pool[i] dials from source address sip[i]; the two slices are
	// parallel and fixed after construction.
	pool []http.RoundTripper
	sip  []string

	// m receives the per-request metrics.
	m *roundRobinMetrics

	// mu guards last, the index into pool of the most recently chosen
	// transport.
	mu   sync.Mutex
	last int
}
// RoundTrip implements http.RoundTripper. Each call advances to the next
// transport in rt.pool (wrapping around) and sends req through it, so
// consecutive requests leave from consecutive source IPs.
//
// All metrics are labelled with the handler name from req's context:
//   - the in-flight gauge is raised for the duration of the call,
//   - the request counter is incremented,
//   - the latency histogram observes the upstream round trip in seconds,
//   - the response counter is also labelled with the source IP used and
//     with the upstream status code, or "error" when the transport failed
//     without producing a response.
func (rt *roundRobinRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// Choose the transport under the lock and keep the index in a local.
	// Re-reading rt.last after Unlock would race with concurrent requests:
	// a request could be sent through one transport but recorded against
	// another transport's source IP.
	rt.mu.Lock()
	rt.last = (rt.last + 1) % len(rt.pool)
	idx := rt.last
	rt.mu.Unlock()

	labels := prometheus.Labels{"handler": handlerName(req.Context())}
	inflight := rt.m.mInflight.With(labels)
	inflight.Inc()
	defer inflight.Dec()
	rt.m.mRequests.With(labels).Inc()

	responses, err := rt.m.mResponses.CurryWith(labels)
	if err != nil {
		// The RoundTripper contract requires closing the request body even
		// when returning an error.
		if req.Body != nil {
			_ = req.Body.Close()
		}
		return nil, err
	}

	start := time.Now()
	resp, err := rt.pool[idx].RoundTrip(req)
	rt.m.mLatency.With(labels).Observe(time.Since(start).Seconds())

	// resp is nil when the transport fails; dereferencing it unconditionally
	// would panic.
	code := "error"
	if resp != nil {
		code = fmt.Sprintf("%d", resp.StatusCode)
	}
	responses.With(prometheus.Labels{"responseCode": code, "sourceIP": rt.sip[idx]}).Inc()
	return resp, err
}
// NewRoundRobinRoundtripper builds a roundRobinRoundtripper with one
// *http.Transport per entry in ips. Each transport's dialer binds outgoing
// connections to that local IP, with port 0 so the OS picks the port.
// m receives the per-request metrics recorded by RoundTrip.
//
// It returns an error in two cases:
//   - ips is empty, since RoundTrip needs at least one transport to rotate
//     through;
//   - an entry does not resolve as a local TCP address.
func NewRoundRobinRoundtripper(ips []string, m *roundRobinMetrics) (*roundRobinRoundtripper, error) {
	if len(ips) == 0 {
		return nil, fmt.Errorf("no outbound IPs given")
	}
	pool := make([]http.RoundTripper, 0, len(ips))
	sip := make([]string, 0, len(ips))
	for _, oip := range ips {
		// JoinHostPort brackets IPv6 literals ("[::1]:0"). Formatting with
		// "%v:0" would produce "::1:0", which does not parse.
		a, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(oip, "0"))
		if err != nil {
			return nil, fmt.Errorf("resolving outbound address %q: %w", oip, err)
		}
		pool = append(pool, &http.Transport{
			// Dialer.DualStack is deliberately not set: it is deprecated and
			// ignored, because fast fallback is enabled by default.
			DialContext: (&net.Dialer{
				Timeout:   30 * time.Second,
				KeepAlive: 120 * time.Second,
				LocalAddr: a,
			}).DialContext,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		})
		sip = append(sip, oip)
	}
	return &roundRobinRoundtripper{
		pool: pool,
		sip:  sip,
		m:    m,
	}, nil
}
// key is the unexported type for this file's context keys, so they cannot
// collide with context values set by other packages.
type key int

// Context keys. requestIDKey (string) and requestStartKey (time.Time) are
// set by requestIDMiddleware; handlerKey (string) is set by
// attachHandlerMiddleware.
const (
	requestIDKey key = iota
	requestStartKey
	handlerKey
)
// newRequestID returns a newly generated UUID in string form. It is used
// to tag each proxied request so its REQ and RESP log lines can be matched.
func newRequestID() string {
	id := uuid.New()
	return id.String()
}
// requestID returns the request ID that requestIDMiddleware stored in ctx.
// If ctx carries no ID (for example, a request that bypassed the
// middleware), it returns "" instead of panicking on a failed type
// assertion.
func requestID(ctx context.Context) string {
	id, _ := ctx.Value(requestIDKey).(string)
	return id
}
// timeSinceStart reports how much time has elapsed since the start time
// that requestIDMiddleware stored in ctx. If ctx has no start time, it
// returns 0 instead of panicking.
func timeSinceStart(ctx context.Context) time.Duration {
	start, ok := ctx.Value(requestStartKey).(time.Time)
	if !ok {
		return 0
	}
	return time.Since(start)
}
// handlerName returns the handler label that attachHandlerMiddleware
// stored in ctx. RoundTrip uses it as the "handler" metrics label. If ctx
// has none, it returns "" instead of panicking.
func handlerName(ctx context.Context) string {
	name, _ := ctx.Value(handlerKey).(string)
	return name
}
// requestIDMiddleware wraps next so that every request's context carries
// two values before next runs:
//   - a freshly generated request ID, under requestIDKey;
//   - the time the request reached this middleware, under requestStartKey.
func requestIDMiddleware(next http.Handler) http.Handler {
	stamp := func(w http.ResponseWriter, r *http.Request) {
		withID := context.WithValue(r.Context(), requestIDKey, newRequestID())
		withStart := context.WithValue(withID, requestStartKey, time.Now())
		next.ServeHTTP(w, r.WithContext(withStart))
	}
	return http.HandlerFunc(stamp)
}
// attachHandlerMiddleware wraps next so that the request context carries
// handler under handlerKey. RoundTrip later reads it back through
// handlerName to label its metrics.
func attachHandlerMiddleware(next http.Handler, handler string) http.Handler {
	tag := func(w http.ResponseWriter, r *http.Request) {
		tagged := r.WithContext(context.WithValue(r.Context(), handlerKey, handler))
		next.ServeHTTP(w, tagged)
	}
	return http.HandlerFunc(tag)
}
// newReverseProxy returns a handler that forwards requests to baseURL
// through rrrt. It logs each outbound request (REQ) and each upstream
// response (RESP) with the request ID set by requestIDMiddleware.
//
// The returned handler strips baseURL.Path from the incoming path, and the
// reverse proxy then re-joins the remainder onto baseURL.Path. Routing is
// therefore only correct when the handler is mounted at a prefix equal to
// baseURL.Path, as in the default -outbound_map.
// NOTE(review): http.StripPrefix answers 404 for paths without that prefix,
// so a mapping whose local prefix differs from the remote path will not
// route. Confirm before adding such mappings.
func newReverseProxy(baseURL *url.URL, rrrt http.RoundTripper) http.Handler {
	rp := httputil.NewSingleHostReverseProxy(baseURL)
	rp.Transport = rrrt

	direct := rp.Director
	rp.Director = func(req *http.Request) {
		direct(req)
		// Send the upstream's own host instead of the client's Host. net/http's
		// client takes the Host line from req.Host and ignores any "Host"
		// entry in req.Header, so setting that header had no effect.
		req.Host = req.URL.Host
		glog.Infof("REQ [%v] [%v] %v %v (%v)", requestID(req.Context()), req.RemoteAddr, req.Method, req.URL, req.Header.Get("User-Agent"))
	}
	rp.ModifyResponse = func(resp *http.Response) error {
		req := resp.Request
		ctx := req.Context()
		glog.Infof("RESP [%v] [%v] %v %v (%v): [%v] [%v]", requestID(ctx), req.RemoteAddr, req.Method, req.URL, req.Header.Get("User-Agent"), timeSinceStart(ctx), resp.Status)
		return nil
	}
	return http.StripPrefix(baseURL.Path, rp)
}
// listener is a net.Listener that accepts connections from tcpListener and
// wraps each one as a TLS server connection using conf. If net.Listen
// failed, the error is kept in tcpListenErr, and Accept and Close report it
// instead of touching tcpListener.
type listener struct {
	// tcpListener and tcpListenErr hold the result of net.Listen.
	tcpListener  net.Listener
	tcpListenErr error

	// conf is the TLS configuration applied to every accepted connection.
	conf *tls.Config

	// m is the autocert manager whose GetCertificate main installs in conf.
	// NOTE(review): no listener method reads m.
	m *autocert.Manager
}
// Accept waits for the next connection and returns it wrapped as a TLS
// server connection using ln.conf. For TCP connections it first enables
// keep-alives with a 3-minute period. The TLS handshake is not run here;
// crypto/tls performs it on the first Read or Write of the returned conn.
//
// If the original net.Listen failed, that error is returned on every call.
func (ln *listener) Accept() (net.Conn, error) {
	if ln.tcpListenErr != nil {
		return nil, ln.tcpListenErr
	}
	conn, err := ln.tcpListener.Accept()
	if err != nil {
		return nil, err
	}
	// A checked assertion avoids panicking if tcpListener ever yields a
	// non-TCP connection. Keep-alives are best effort, so a failure to set
	// them does not drop the connection.
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		_ = tcpConn.SetKeepAlive(true)
		_ = tcpConn.SetKeepAlivePeriod(3 * time.Minute)
	}
	return tls.Server(conn, ln.conf), nil
}
// Addr returns the bound address of the underlying listener. When there is
// no underlying listener (net.Listen failed), it returns the placeholder
// 0.0.0.0:443, so callers that ask for Addr before Accept never get nil.
func (ln *listener) Addr() net.Addr {
	if ln.tcpListener == nil {
		return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 443}
	}
	return ln.tcpListener.Addr()
}
// Close closes the underlying listener. If net.Listen had failed, there is
// nothing to close, and the original listen error is returned instead.
func (ln *listener) Close() error {
	if listenErr := ln.tcpListenErr; listenErr != nil {
		return listenErr
	}
	return ln.tcpListener.Close()
}
// main sets up and runs the proxy:
//   - registers the Prometheus metrics;
//   - builds the round-robin outbound transport from -outbound_bind;
//   - mounts one reverse-proxy handler per -outbound_map entry, plus
//     /metrics, on http.DefaultServeMux;
//   - serves on -server_bind, over TLS with autocert-managed certificates
//     when -autocert_domain is set, otherwise over plain HTTP.
//
// Any configuration or serving error exits the process via glog.
func main() {
	flag.Parse()

	metrics := &roundRobinMetrics{
		mInflight: prometheus.NewGaugeVec(
			prometheus.GaugeOpts{
				Name: "proxy_inflight",
				Help: "Count of current inflight HTTP requests, per handler",
			},
			[]string{"handler"},
		),
		mLatency: prometheus.NewHistogramVec(
			prometheus.HistogramOpts{
				Name: "proxy_latency",
				Help: "Histogram of request latency distributions",
			},
			[]string{"handler"},
		),
		mRequests: prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "proxy_requests",
				Help: "Count of requests that have passed through this proxy",
			},
			[]string{"handler"},
		),
		mResponses: prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "proxy_responses",
				Help: "Count of responses that have passed through this proxy",
			},
			[]string{"handler", "sourceIP", "responseCode"},
		),
	}
	prometheus.MustRegister(metrics.mInflight, metrics.mLatency, metrics.mRequests, metrics.mResponses)

	outboundIPs := strings.Split(*outboundBind, ",")
	rrrt, err := NewRoundRobinRoundtripper(outboundIPs, metrics)
	if err != nil {
		glog.Exitf("NewRoundRobinRoundtripper: %v", err)
	}

	for _, ms := range strings.Split(*outboundMap, ",") {
		// Split on the first "=" only, so the remote URL may itself contain
		// "=". Reject entries with no "=" instead of panicking on an
		// out-of-range index.
		bits := strings.SplitN(ms, "=", 2)
		if len(bits) != 2 {
			glog.Exitf("malformed -outbound_map entry %q: want [path prefix]=[remote base URL]", ms)
		}
		localPath, remote := bits[0], bits[1]
		remoteURL, err := url.Parse(remote)
		if err != nil {
			glog.Exitf("url.Parse(%q): %v", remote, err)
		}
		rp := newReverseProxy(remoteURL, rrrt)
		// The remote base URL doubles as the "handler" metrics label.
		http.Handle(localPath, attachHandlerMiddleware(requestIDMiddleware(rp), remote))
	}
	http.Handle("/metrics", promhttp.Handler())
	glog.Info("Initialization complete. Ready to go.")

	if *autocertDomain != "" {
		// Named mgr so it does not shadow the metrics bundle above.
		mgr := &autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(*autocertDomain),
			Cache:      autocert.DirCache(*autocertCacheDir),
		}
		ln := &listener{
			m: mgr,
			conf: &tls.Config{
				GetCertificate: mgr.GetCertificate,
				NextProtos:     []string{"h2", "http/1.1"},
			},
		}
		// A listen error is stored on ln and surfaces from Accept, which
		// makes http.Serve return it below.
		ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", *serverBind)
		// Plain-HTTP side on -autocert_insecure_bind for HTTP-01 challenges.
		go func() {
			glog.Fatal(http.ListenAndServe(*autocertInsecureBind, mgr.HTTPHandler(nil)))
		}()
		glog.Fatal(http.Serve(ln, nil))
	} else {
		glog.Fatal(http.ListenAndServe(*serverBind, nil))
	}
}