package main

import (
	"context"
	"crypto/tls"
	"flag"
	"fmt"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/golang/glog"
	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"golang.org/x/crypto/acme/autocert"
)

// Proxy flags: the proxy's own listen address, the local addresses that
// outbound requests are bound to, and the mapping of incoming path prefixes
// onto upstream base URLs.
var (
	serverBind   = flag.String("server_bind", "185.230.222.225:13231", "IP address/port to bind proxy server to")
	outboundBind = flag.String("outbound_bind", generateOutbound(), "IP addresses to bind outbound requests to")
	outboundMap  = flag.String("outbound_map", "/users/=https://api.mojang.com/users/,/session/=https://sessionserver.mojang.com/session/", "Comma-separated mapping of [path prefix]=[remote base URL]")
)

// Autocert flags: a non-empty domain switches the server to TLS via
// autocert, with an HTTP listener for the HTTP-01 challenge and an on-disk
// certificate cache.
var (
	autocertDomain       = flag.String("autocert_domain", "", "If non-empty, enables autocert for the provided domain.")
	autocertInsecureBind = flag.String("autocert_insecure_bind", "185.230.222.225:80", "Port 80 IP to bind for HTTP-01 ACME challenge.")
	autocertCacheDir     = flag.String("autocert_cache_dir", "/home/minotarproxy/autocertcache", "Autocert cache directory.")
)

// generateOutbound builds the default value of -outbound_bind: every address
// from 185.230.222.225 through 185.230.222.253 inclusive, joined by commas.
func generateOutbound() string {
	const firstHost, lastHost = 225, 253
	addrs := make([]string, 0, lastHost-firstHost+1)
	for host := firstHost; host <= lastHost; host++ {
		addrs = append(addrs, fmt.Sprintf("185.230.222.%d", host))
	}
	return strings.Join(addrs, ",")
}

// roundRobinMetrics bundles the Prometheus collectors updated by
// roundRobinRoundtripper.RoundTrip. Label names are fixed where the vectors
// are constructed; RoundTrip supplies "handler" for every collector, plus
// "sourceIP" and "responseCode" for mResponses.
type roundRobinMetrics struct {
	// mInflight tracks requests currently inside RoundTrip.
	mInflight  *prometheus.GaugeVec
	// mLatency observes upstream round-trip latency, in seconds.
	mLatency   *prometheus.HistogramVec
	// mRequests counts RoundTrip calls.
	mRequests  *prometheus.CounterVec
	// mResponses counts upstream responses per source IP and status code.
	mResponses *prometheus.CounterVec
}

// roundRobinRoundtripper rotates requests through a pool of transports, one
// per outbound source IP, and records metrics in m. *roundRobinRoundtripper
// implements http.RoundTripper.
type roundRobinRoundtripper struct {
	// mu guards last.
	mu   sync.Mutex
	// sip[i] is the source IP that pool[i] binds its connections to.
	sip  []string
	pool []http.RoundTripper
	// last is the pool index chosen by the most recent RoundTrip call.
	last int

	m *roundRobinMetrics
}

// RoundTrip implements http.RoundTripper.
//
// Each call advances to the next transport in the pool (round-robin) and
// forwards req through it. It records in-flight, latency, request and
// response metrics labelled with the handler name from req's context.
// Responses are also labelled with the source IP used and with the status
// code, or "error" when the transport failed.
func (rt *roundRobinRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// Copy the chosen slot while holding the lock. Reading rt.last again
	// after Unlock would race with concurrent calls, and could pair one
	// slot's transport with another slot's source-IP label.
	rt.mu.Lock()
	rt.last = (rt.last + 1) % len(rt.pool)
	idx := rt.last
	rt.mu.Unlock()

	handler := handlerName(req.Context())
	labels := prometheus.Labels{"handler": handler}
	gi := rt.m.mInflight.With(labels)
	gi.Inc()
	defer gi.Dec()
	hl := rt.m.mLatency.With(labels)
	rt.m.mRequests.With(labels).Inc()
	cvr, err := rt.m.mResponses.CurryWith(labels)
	if err != nil {
		// The RoundTripper contract requires closing the request body even
		// when returning an error; a close failure adds nothing here.
		if req.Body != nil {
			_ = req.Body.Close()
		}
		return nil, err
	}

	startTime := time.Now()
	resp, err := rt.pool[idx].RoundTrip(req)
	hl.Observe(time.Since(startTime).Seconds())

	// When err is non-nil, resp may be nil and must be ignored. Count the
	// failure under a sentinel code rather than dereferencing resp.
	responseCode := "error"
	if err == nil {
		responseCode = fmt.Sprintf("%d", resp.StatusCode)
	}
	cvr.With(prometheus.Labels{"responseCode": responseCode, "sourceIP": rt.sip[idx]}).Inc()

	return resp, err
}

// NewRoundRobinRoundtripper returns a round-robin RoundTripper with one
// http.Transport per entry of ips.
//
// Each transport's dialer binds outgoing connections to that local IP on an
// ephemeral port. m receives the metrics recorded by RoundTrip.
//
// It returns an error if ips is empty, because RoundTrip needs at least one
// transport, or if any entry fails to resolve.
func NewRoundRobinRoundtripper(ips []string, m *roundRobinMetrics) (*roundRobinRoundtripper, error) {
	if len(ips) == 0 {
		return nil, fmt.Errorf("no outbound IPs given")
	}
	pool := make([]http.RoundTripper, 0, len(ips))
	sip := make([]string, 0, len(ips))
	for _, oip := range ips {
		// Plain "host:0" concatenation mis-parses IPv6 literals; JoinHostPort
		// brackets them ("::1" -> "[::1]:0"). Trim existing brackets first
		// so an already bracketed "[::1]" also yields "[::1]:0".
		host := strings.TrimSuffix(strings.TrimPrefix(oip, "["), "]")
		a, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(host, "0"))
		if err != nil {
			return nil, fmt.Errorf("resolving outbound IP %q: %w", oip, err)
		}
		pool = append(pool, &http.Transport{
			// net.Dialer.DualStack is deprecated and ignored: Fast Fallback
			// is enabled by default, so it is not set here.
			DialContext: (&net.Dialer{
				Timeout:   30 * time.Second,
				KeepAlive: 120 * time.Second,
				LocalAddr: a,
			}).DialContext,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		})
		sip = append(sip, oip)
	}

	return &roundRobinRoundtripper{
		pool: pool,
		sip:  sip,
		m:    m,
	}, nil
}

// key is the unexported type of this file's context keys, so they cannot
// collide with context values set by other packages.
type key int

// Context keys written by the request middlewares and read by their accessors.
const (
	requestIDKey    key = iota // request ID (string)
	requestStartKey            // time the request entered the middleware (time.Time)
	handlerKey                 // handler name (string)
)

// newRequestID returns a freshly generated UUID, in string form, used to
// correlate the log lines of a single proxied request.
func newRequestID() string {
	id := uuid.New()
	return id.String()
}

// requestID returns the request ID that requestIDMiddleware stored in ctx.
// It returns "" when ctx carries none, rather than panicking on an unchecked
// type assertion.
func requestID(ctx context.Context) string {
	// A missing or mistyped value deliberately falls back to the zero value.
	id, _ := ctx.Value(requestIDKey).(string)
	return id
}

// timeSinceStart returns the time elapsed since the start time that
// requestIDMiddleware stored in ctx. It returns 0 when ctx carries no start
// time, rather than panicking.
func timeSinceStart(ctx context.Context) time.Duration {
	start, ok := ctx.Value(requestStartKey).(time.Time)
	if !ok {
		return 0
	}
	return time.Since(start)
}

// handlerName returns the handler name that attachHandlerMiddleware stored in
// ctx. It returns "" when ctx carries none, so callers such as the transport's
// metrics code do not panic on an unlabelled request.
func handlerName(ctx context.Context) string {
	// A missing or mistyped value deliberately falls back to the zero value.
	name, _ := ctx.Value(handlerKey).(string)
	return name
}

// requestIDMiddleware wraps next so that each request's context carries a
// newly generated request ID and the time the request reached this middleware.
func requestIDMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		withID := context.WithValue(r.Context(), requestIDKey, newRequestID())
		withStart := context.WithValue(withID, requestStartKey, time.Now())
		next.ServeHTTP(w, r.WithContext(withStart))
	})
}

// attachHandlerMiddleware wraps next so that each request's context carries
// handler under handlerKey, for later use as a metrics label.
func attachHandlerMiddleware(next http.Handler, handler string) http.Handler {
	tag := func(w http.ResponseWriter, r *http.Request) {
		tagged := r.WithContext(context.WithValue(r.Context(), handlerKey, handler))
		next.ServeHTTP(w, tagged)
	}
	return http.HandlerFunc(tag)
}

// newReverseProxy returns a handler that forwards requests to baseURL using
// rrrt as the transport. It logs each outbound request, each upstream
// response and each transport failure, tagged with the request ID from the
// request context.
//
// Incoming paths have baseURL.Path stripped before the reverse proxy
// re-joins them onto baseURL. NOTE(review): this only routes correctly when
// the handler is mounted at a prefix equal to baseURL.Path. http.StripPrefix
// answers 404 for paths that lack that prefix.
func newReverseProxy(baseURL *url.URL, rrrt http.RoundTripper) http.Handler {
	rp := httputil.NewSingleHostReverseProxy(baseURL)
	director := rp.Director
	rp.Director = func(req *http.Request) {
		director(req)
		// Present the upstream's own host name. Only req.Host matters here:
		// net/http ignores a "Host" entry in req.Header when writing requests.
		req.Host = req.URL.Host
		glog.Infof("REQ [%v] [%v] %v %v (%v)", requestID(req.Context()), req.RemoteAddr, req.Method, req.URL, req.Header.Get("User-Agent"))
	}
	rp.ModifyResponse = func(resp *http.Response) error {
		req := resp.Request
		ctx := req.Context()
		glog.Infof("RESP [%v] [%v] %v %v (%v): [%v] [%v]", requestID(ctx), req.RemoteAddr, req.Method, req.URL, req.Header.Get("User-Agent"), timeSinceStart(ctx), resp.Status)
		return nil
	}
	// Same client-visible behaviour as the default handler (502 Bad Gateway),
	// but the failure is logged through glog with the request ID.
	rp.ErrorHandler = func(w http.ResponseWriter, req *http.Request, err error) {
		ctx := req.Context()
		glog.Warningf("ERR [%v] [%v] %v %v (%v): [%v] [%v]", requestID(ctx), req.RemoteAddr, req.Method, req.URL, req.Header.Get("User-Agent"), timeSinceStart(ctx), err)
		w.WriteHeader(http.StatusBadGateway)
	}
	rp.Transport = rrrt
	return http.StripPrefix(baseURL.Path, rp)
}

// listener is a net.Listener that wraps each connection accepted from
// tcpListener in a server-side TLS connection configured by conf.
// If tcpListenErr is set (for example because creating tcpListener failed),
// Accept and Close return that error instead of touching tcpListener.
type listener struct {
	m    *autocert.Manager // stored for the caller; not read by these methods
	conf *tls.Config       // TLS configuration applied to every accepted connection

	tcpListener  net.Listener
	tcpListenErr error
}

// Accept waits for the next connection, enables TCP keep-alives on it when it
// is a TCP connection, and returns it wrapped by tls.Server.
func (ln *listener) Accept() (net.Conn, error) {
	if ln.tcpListenErr != nil {
		return nil, ln.tcpListenErr
	}
	conn, err := ln.tcpListener.Accept()
	if err != nil {
		return nil, err
	}
	// Keep-alives are best-effort: a failure to enable them should not reject
	// the connection. A non-TCP connection is passed through without them
	// instead of panicking on an unchecked type assertion.
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		_ = tcpConn.SetKeepAlive(true)
		_ = tcpConn.SetKeepAlivePeriod(3 * time.Minute)
	}
	return tls.Server(conn, ln.conf), nil
}

// Addr returns the underlying listener's address.
func (ln *listener) Addr() net.Addr {
	if ln.tcpListener != nil {
		return ln.tcpListener.Addr()
	}
	// net.Listen failed. Return something non-nil in case callers
	// call Addr before Accept:
	return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 443}
}

// Close closes the underlying listener, or reports the stored listen error.
func (ln *listener) Close() error {
	if ln.tcpListenErr != nil {
		return ln.tcpListenErr
	}
	return ln.tcpListener.Close()
}

// main parses flags and registers the proxy's Prometheus metrics. It then
// mounts one reverse proxy per -outbound_map entry, all sharing a single
// round-robin transport pool bound to the -outbound_bind addresses, and
// exposes /metrics.
//
// Finally it serves either plain HTTP on -server_bind, or, when
// -autocert_domain is set, TLS on -server_bind plus an HTTP-01 challenge
// handler on -autocert_insecure_bind.
func main() {
	flag.Parse()

	m := &roundRobinMetrics{
		mInflight: prometheus.NewGaugeVec(
			prometheus.GaugeOpts{
				Name: "proxy_inflight",
				// The only label is "handler", so the help text must not
				// promise a per-source-IP breakdown.
				Help: "Count of current inflight HTTP requests, per handler",
			},
			[]string{"handler"},
		),
		mLatency: prometheus.NewHistogramVec(
			prometheus.HistogramOpts{
				Name: "proxy_latency",
				Help: "Histogram of request latency distributions",
			},
			[]string{"handler"},
		),
		mRequests: prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "proxy_requests",
				Help: "Count of requests that have passed through this proxy",
			},
			[]string{"handler"},
		),
		mResponses: prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "proxy_responses",
				Help: "Count of responses that have passed through this proxy",
			},
			[]string{"handler", "sourceIP", "responseCode"},
		),
	}
	prometheus.MustRegister(m.mInflight, m.mLatency, m.mRequests, m.mResponses)

	outboundIPs := strings.Split(*outboundBind, ",")
	rrrt, err := NewRoundRobinRoundtripper(outboundIPs, m)
	if err != nil {
		glog.Exitf("NewRoundRobinRoundtripper: %v", err)
	}

	for _, ms := range strings.Split(*outboundMap, ",") {
		// Split on the first "=" only, so remote URLs may themselves contain
		// "=" (e.g. in a query string). Reject entries without "=" instead of
		// panicking on an out-of-range index.
		bits := strings.SplitN(ms, "=", 2)
		if len(bits) != 2 {
			glog.Exitf("invalid -outbound_map entry %q: want [path prefix]=[remote base URL]", ms)
		}
		localPath, remote := bits[0], bits[1]
		remoteURL, err := url.Parse(remote)
		if err != nil {
			glog.Exitf("url.Parse(%q): %v", remote, err)
		}

		rp := newReverseProxy(remoteURL, rrrt)
		http.Handle(localPath, attachHandlerMiddleware(requestIDMiddleware(rp), remote))
	}

	http.Handle("/metrics", promhttp.Handler())
	glog.Info("Initialization complete. Ready to go.")
	if *autocertDomain != "" {
		// Named certManager so it does not shadow the metrics struct m.
		certManager := &autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(*autocertDomain),
			Cache:      autocert.DirCache(*autocertCacheDir),
		}
		// Fail fast with a clear message instead of surfacing the error from
		// the first Accept, after the challenge listener has already started.
		tcpListener, err := net.Listen("tcp", *serverBind)
		if err != nil {
			glog.Exitf("net.Listen(%q): %v", *serverBind, err)
		}
		ln := &listener{
			m: certManager,
			conf: &tls.Config{
				GetCertificate: certManager.GetCertificate,
				NextProtos:     []string{"h2", "http/1.1"},
			},
			tcpListener: tcpListener,
		}
		go func() {
			glog.Fatal(http.ListenAndServe(*autocertInsecureBind, certManager.HTTPHandler(nil)))
		}()
		glog.Fatal(http.Serve(ln, nil))
	} else {
		glog.Fatal(http.ListenAndServe(*serverBind, nil))
	}
}