package main

import (
	"context"
	"errors"
	"flag"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/pgxpool"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc"
	"google.golang.org/protobuf/encoding/protojson"
	"hg.lukegb.com/lukegb/depot/go/trains/webapi"
	"hg.lukegb.com/lukegb/depot/go/trains/webapi/summarize"
)

// allowStress enables the stress check served at /stress for requests sent
// with "Accept: text/plain"; it summarizes every train service in the database.
var allowStress = flag.Bool("allow_stress", false, "Allow checking the validity of every record in the database")

// TLS settings. TLS is used only when both are set.
var (
	certFile = flag.String("cert_file", "", "Path to TLS certificate. If empty, listens on plaintext HTTP instead.")
	keyFile  = flag.String("key_file", "", "Path to TLS private key.")
)

// watchingResponseWriter wraps an http.ResponseWriter and records whether a
// status or body has been written, which status was sent, and how many body
// bytes the underlying writer accepted. ServeHTTP uses this to log the result
// and to decide whether an error response can still be sent.
type watchingResponseWriter struct {
	http.ResponseWriter
	// Flusher, if non-nil, receives Flush calls.
	http.Flusher
	// WrittenTo is true once WriteHeader or Write has been called.
	WrittenTo bool
	// WroteStatus is the status code sent; http.StatusOK if Write came first.
	WroteStatus int
	// WroteContentBytes counts body bytes reported written by the underlying writer.
	WroteContentBytes int
}

// Flush forwards to the wrapped Flusher and is a no-op when there is none.
func (rw *watchingResponseWriter) Flush() {
	if rw.Flusher != nil {
		rw.Flusher.Flush()
	}
}

// Header returns the wrapped writer's header map.
func (rw *watchingResponseWriter) Header() http.Header { return rw.ResponseWriter.Header() }

// Write records an implicit 200 status on first use, then forwards bs to the
// wrapped writer. Only the bytes the wrapped writer reports as written are
// counted, so short writes are not over-reported.
func (rw *watchingResponseWriter) Write(bs []byte) (int, error) {
	if !rw.WrittenTo {
		rw.WroteStatus = http.StatusOK
	}
	rw.WrittenTo = true
	n, err := rw.ResponseWriter.Write(bs)
	rw.WroteContentBytes += n
	return n, err
}

// WriteHeader sends statusCode unless a status or body has already been
// written, in which case the call is silently ignored.
func (rw *watchingResponseWriter) WriteHeader(statusCode int) {
	if rw.WrittenTo {
		return
	}
	rw.WrittenTo = true
	rw.WroteStatus = statusCode
	rw.ResponseWriter.WriteHeader(statusCode)
}

// httpError is an error that carries the HTTP status code and client-facing
// message ServeHTTP should respond with. err, if non-nil, is the underlying
// cause and is what Error reports.
type httpError struct {
	err            error
	httpStatusCode int
	publicError    string
}

// Error returns the underlying error's text, or "HTTP <status>: <publicError>"
// when there is no underlying error.
func (he httpError) Error() string {
	if he.err == nil {
		return fmt.Sprintf("HTTP %d: %v", he.httpStatusCode, he.publicError)
	}
	return he.err.Error()
}

// Unwrap exposes the underlying cause (possibly nil) to errors.Is and errors.As.
func (he httpError) Unwrap() error { return he.err }

// grpcServer is the implementation registered with
// webapi.RegisterTrainsServiceServer in main. It embeds the generated
// UnimplementedTrainsServiceServer (presumably providing stubs for any RPCs
// not implemented here, per the usual protoc-gen-go-grpc pattern — confirm
// against the generated code).
type grpcServer struct {
	webapi.UnimplementedTrainsServiceServer

	// dbPool supplies a database connection for the duration of each RPC.
	dbPool *pgxpool.Pool
}

// SummarizeService returns the summary of the train service identified by
// req's ID, computed by summarize.Service on a connection borrowed from the
// pool for the duration of the call. Errors from summarize.Service are
// returned unwrapped.
func (s *grpcServer) SummarizeService(ctx context.Context, req *webapi.SummarizeServiceRequest) (*webapi.ServiceData, error) {
	pooled, acquireErr := s.dbPool.Acquire(ctx)
	if acquireErr != nil {
		return nil, fmt.Errorf("dbPool.Acquire: %w", acquireErr)
	}
	defer pooled.Release()

	serviceID := int(req.GetId())
	return summarize.Service(ctx, pooled, serviceID)
}

// ListServices returns the services selected by req, as produced by
// summarize.ListServices on a pooled connection held for the call.
func (s *grpcServer) ListServices(ctx context.Context, req *webapi.ListServicesRequest) (*webapi.ServiceList, error) {
	pooled, acquireErr := s.dbPool.Acquire(ctx)
	if acquireErr != nil {
		return nil, fmt.Errorf("dbPool.Acquire: %w", acquireErr)
	}
	defer pooled.Release()

	services, listErr := summarize.ListServices(ctx, pooled, req)
	if listErr != nil {
		return nil, fmt.Errorf("ListServices: %w", listErr)
	}
	result := &webapi.ServiceList{Services: services}
	return result, nil
}

// ListLocations returns reference locations (ref_locations) joined with their
// operating company (ref_tocs). Unless req.IncludeWithoutCrs is set, only
// locations with a non-NULL, non-empty CRS are included. Each location is
// canonicalized before being returned.
func (s *grpcServer) ListLocations(ctx context.Context, req *webapi.ListLocationsRequest) (*webapi.LocationList, error) {
	conn, err := s.dbPool.Acquire(ctx)
	if err != nil {
		return nil, fmt.Errorf("dbPool.Acquire: %w", err)
	}
	defer conn.Release()

	// The WHERE clause is picked from fixed strings only (never from request
	// data), so building the query with Sprintf is not an injection risk.
	where := "TRUE"
	if !req.GetIncludeWithoutCrs() {
		where = "rl.crs IS NOT NULL AND rl.crs <> ''"
	}
	// NOTE(review): with IncludeWithoutCrs set, rows whose crs is NULL are
	// returned; confirm loc.Crs's Go type can receive NULL from Scan.
	rows, err := conn.Query(ctx, fmt.Sprintf(`
SELECT
	rl.tiploc, rl.name, rl.crs,
	rltoc.toc, rltoc.name, rltoc.url
FROM 
	ref_locations rl
INNER JOIN
	ref_tocs rltoc ON rltoc.toc = rl.toc
WHERE %s`, where))
	if err != nil {
		return nil, fmt.Errorf("querying for locations: %w", err)
	}
	// Close rows on every exit path: returning early (e.g. on a Scan error)
	// with rows still open leaves the connection busy when it is released.
	defer rows.Close()

	var locs []*webapi.Location
	for rows.Next() {
		loc := &webapi.Location{Operator: &webapi.TrainOperator{}}
		if err := rows.Scan(&loc.Tiploc, &loc.Name, &loc.Crs, &loc.Operator.Code, &loc.Operator.Name, &loc.Operator.Url); err != nil {
			return nil, fmt.Errorf("scanning for locations: %w", err)
		}
		loc.Canonicalize()
		locs = append(locs, loc)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("rows error: %w", err)
	}

	return &webapi.LocationList{Locations: locs}, nil
}

// server is the http.Handler installed in main. It hands gRPC requests to
// grpcServer and serves the JSON, event-stream and (optional) stress-test
// endpoints using connections from dbPool.
type server struct {
	// dbPool provides database connections for the HTTP handlers.
	dbPool     *pgxpool.Pool
	// grpcServer receives HTTP/2 requests whose Content-Type starts with application/grpc.
	grpcServer *grpc.Server
}

// stress summarizes and protojson-marshals every row of train_services,
// using a fixed set of concurrent workers, and returns the first failure
// (nil if every service succeeds). Only the /stress path is served; nothing
// is written to rw on success. Each processed ID is logged.
func (s *server) stress(ctx context.Context, rw http.ResponseWriter, r *http.Request) error {
	if r.URL.Path != "/stress" {
		return httpError{
			httpStatusCode: http.StatusNotFound, publicError: "not found",
		}
	}

	// Read the full ID list and give that connection back before any worker
	// starts. Previously the producer competed with 32 workers for pool
	// connections. If the workers took every slot, they waited forever for
	// IDs that the producer could never query.
	ids, err := s.stressServiceIDs(ctx)
	if err != nil {
		return err
	}

	// stress test mode, let's go
	const numWorkers = 32
	ch := make(chan int, numWorkers)
	eg, egCtx := errgroup.WithContext(ctx)
	for n := 0; n < numWorkers; n++ {
		eg.Go(func() error {
			return s.stressWorker(egCtx, ch)
		})
	}

	// Feed IDs until exhausted or until a worker fails (which cancels egCtx).
	// Closing ch tells the remaining workers to exit.
loop:
	for _, id := range ids {
		select {
		case ch <- id:
		case <-egCtx.Done():
			break loop
		}
	}
	close(ch)

	return eg.Wait()
}

// stressServiceIDs returns every train_services ID in ascending order.
func (s *server) stressServiceIDs(ctx context.Context) ([]int, error) {
	idConn, err := s.dbPool.Acquire(ctx)
	if err != nil {
		return nil, fmt.Errorf("acquiring idConn: %w", err)
	}
	defer idConn.Release()

	rows, err := idConn.Query(ctx, "SELECT id FROM train_services ORDER BY id")
	if err != nil {
		return nil, fmt.Errorf("querying for all IDs: %w", err)
	}
	defer rows.Close()

	var ids []int
	for rows.Next() {
		var id int
		if err := rows.Scan(&id); err != nil {
			return nil, fmt.Errorf("scanning ID: %w", err)
		}
		ids = append(ids, id)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating IDs: %w", err)
	}
	return ids, nil
}

// stressWorker summarizes and marshals each ID received on ch using one
// pooled connection. It returns nil once ch is closed, ctx.Err() on
// cancellation, or the first summarize/marshal error.
func (s *server) stressWorker(ctx context.Context, ch <-chan int) error {
	conn, err := s.dbPool.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("acquiring conn: %w", err)
	}
	defer conn.Release()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case tsid, ok := <-ch:
			// Use the ok flag rather than a zero sentinel so a real ID of 0
			// cannot stop the worker early.
			if !ok {
				return nil
			}
			svc, err := summarize.Service(ctx, conn, tsid)
			if err != nil {
				return fmt.Errorf("summarize.Service %v: %w", tsid, err)
			}

			if _, err := protojson.Marshal(svc); err != nil {
				return fmt.Errorf("protojson.Marshal %v: %w", tsid, err)
			}
			log.Println(tsid)
		}
	}
}

// jsonPathRegexp matches a path ending in "/<id>", where <id> is a decimal
// integer with no leading zero; submatch 1 is the ID.
var jsonPathRegexp = regexp.MustCompile(`/([1-9][0-9]*)$`)

// handleJSON writes the protojson summary of one train service, identified
// by the trailing "/<id>" of the request path. A non-matching path gives 404,
// an ID that strconv.Atoi rejects gives 400, and pgx.ErrNoRows from
// summarize.Service gives 404.
func (s *server) handleJSON(ctx context.Context, rw http.ResponseWriter, r *http.Request) error {
	match := jsonPathRegexp.FindStringSubmatch(r.URL.Path)
	if len(match) != 2 {
		return httpError{httpStatusCode: http.StatusNotFound, publicError: "not found"}
	}
	serviceID, convErr := strconv.Atoi(match[1])
	if convErr != nil {
		return httpError{err: convErr, httpStatusCode: http.StatusBadRequest, publicError: "bad id"}
	}

	dbConn, err := s.dbPool.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("acquiring database connection: %w", err)
	}
	defer dbConn.Release()

	svc, err := summarize.Service(ctx, dbConn, serviceID)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return httpError{err: err, httpStatusCode: http.StatusNotFound, publicError: "no train with that ID"}
	case err != nil:
		return err
	}

	rw.Header().Set("Content-Type", "application/json; charset=utf-8")
	payload, err := protojson.Marshal(svc)
	if err != nil {
		return fmt.Errorf("encoding JSON: %w", err)
	}
	if _, err := rw.Write(payload); err != nil {
		return fmt.Errorf("writing JSON out: %w", err)
	}
	return nil
}

// channelTIDRegex extracts the train service ID from a notification channel
// name containing "tsid-<digits>"; submatch 1 is the ID.
var channelTIDRegex = regexp.MustCompile(`tsid-([0-9]+)`)

// handleEventStream serves a text/event-stream of service summaries for the
// train service IDs given as one or more ?id= query parameters.
//
// For each ID it LISTENs on the Postgres channel "tsid-<id>" and immediately
// writes that service's current summary as a "data:" event. If
// summarize.Service reports pgx.ErrNoRows, it writes a "notyet" event carrying
// the ID instead. It then writes a fresh summary whenever a notification
// arrives on one of those channels, plus an "event: ping" every minute. It
// returns when the request context ends or on the first error.
func (s *server) handleEventStream(ctx context.Context, rw http.ResponseWriter, r *http.Request) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	if r.URL.Path != "/" {
		return httpError{
			httpStatusCode: http.StatusNotFound,
			publicError:    "not found",
		}
	}
	qs, err := url.ParseQuery(r.URL.RawQuery)
	if err != nil {
		return httpError{
			err:            err,
			httpStatusCode: http.StatusBadRequest,
			publicError:    "can't parse query string",
		}
	}

	var ids []int
	for _, idStr := range qs["id"] {
		id, err := strconv.Atoi(idStr)
		if err != nil {
			return httpError{
				err:            err,
				httpStatusCode: http.StatusBadRequest,
				publicError:    "id parameter bad",
			}
		}
		ids = append(ids, id)
	}
	if len(ids) == 0 {
		return httpError{
			httpStatusCode: http.StatusBadRequest,
			publicError:    "need at least one ?id= to watch",
		}
	}

	listenConn, err := s.dbPool.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("acquiring database connection: %w", err)
	}
	defer listenConn.Release()

	for _, id := range ids {
		// id came from strconv.Atoi, so the channel identifier is just digits
		// (and possibly a minus sign); it cannot inject SQL.
		_, err := listenConn.Exec(ctx, fmt.Sprintf(`LISTEN "tsid-%d"`, id))
		if err != nil {
			return fmt.Errorf("watching for %d: %w", id, err)
		}
	}

	rw.Header().Set("Content-Type", "text/event-stream")
	rw.Header().Set("Cache-Control", "no-cache")

	// flush pushes buffered event data to the client when rw supports it.
	flush := func() {
		if f, ok := rw.(http.Flusher); ok {
			f.Flush()
		}
	}

	// encodeSvc writes one event for service id, querying through conn.
	encodeSvc := func(conn summarize.Querier, id int) error {
		svc, err := summarize.Service(ctx, conn, id)
		switch {
		case errors.Is(err, pgx.ErrNoRows):
			// No data for this service yet: emit only the "notyet" event.
			// Falling through would also send the nil summary as a data event.
			if _, err := fmt.Fprintf(rw, "event: notyet\ndata: %d\n\n", id); err != nil {
				return err
			}
			flush()
			return nil
		case err != nil:
			return fmt.Errorf("summarize.Service(%d): %w", id, err)
		}

		if _, err := fmt.Fprint(rw, "data: "); err != nil {
			return fmt.Errorf("writing data: prefix for service %d: %w", id, err)
		}
		svcBytes, err := protojson.Marshal(svc)
		if err != nil {
			return fmt.Errorf("protojson.Marshal service %d: %w", id, err)
		}
		if _, err := rw.Write(svcBytes); err != nil {
			return fmt.Errorf("writing protojson for service %d: %w", id, err)
		}

		if _, err := fmt.Fprint(rw, "\n\n"); err != nil {
			return fmt.Errorf("printing separator after service %d: %w", id, err)
		}

		flush()
		return nil
	}

	for _, id := range ids {
		if err := encodeSvc(listenConn, id); err != nil {
			return fmt.Errorf("initial data for service %d: %w", id, err)
		}
	}

	type updatedMsg struct {
		TID int
		Err error
	}
	updatedCh := make(chan updatedMsg)
	listenerDone := make(chan struct{})

	// The listener goroutine owns listenConn until it exits. It reports each
	// notified service ID (or a terminal error) on updatedCh.
	go func() {
		defer close(listenerDone)
		defer close(updatedCh)
		for {
			n, err := listenConn.Conn().WaitForNotification(ctx)
			if err != nil {
				// Report the error unless we're shutting down, then stop.
				// n must not be used when err is non-nil.
				select {
				case <-ctx.Done():
				case updatedCh <- updatedMsg{Err: err}:
				}
				return
			}

			ms := channelTIDRegex.FindStringSubmatch(n.Channel)
			if len(ms) != 2 {
				log.Printf("unrecognised notification channel %q", n.Channel)
				continue
			}

			id, err := strconv.Atoi(ms[1])
			if err != nil {
				log.Printf("failed to parse TID from notification channel %q: %v", n.Channel, err)
				continue
			}

			select {
			case <-ctx.Done():
				return
			case updatedCh <- updatedMsg{TID: id}:
			}
		}
	}()

	// Stop the listener and wait for it to exit before listenConn is
	// released. Deferred calls run last-in-first-out, so this runs before the
	// Release deferred above. A pgx connection must not be used concurrently,
	// or after it has gone back to the pool.
	defer func() {
		cancel()
		<-listenerDone
	}()

	pingCh := time.NewTicker(1 * time.Minute)
	defer pingCh.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-pingCh.C:
			if _, err := fmt.Fprint(rw, "event: ping\n\n"); err != nil {
				return err
			}
			flush()
		case umsg, ok := <-updatedCh:
			if !ok {
				log.Printf("updatedCh closed, shutting down conn")
				return nil
			}
			if umsg.Err != nil {
				return umsg.Err
			}
			// listenConn belongs to the listener goroutine, so summaries are
			// fetched on a separate pooled connection.
			sconn, err := s.dbPool.Acquire(ctx)
			if err != nil {
				return fmt.Errorf("acquiring database connection: %w", err)
			}
			err = encodeSvc(sconn, umsg.TID)
			sconn.Release()
			if err != nil {
				return err
			}
		}
	}
}

// handleHTTP routes a request. HTTP/2 requests with an application/grpc
// Content-Type go to the gRPC server. Everything else is chosen by the exact
// Accept header:
//   - application/json selects handleJSON;
//   - text/event-stream selects handleEventStream;
//   - text/plain selects stress, but only when --allow_stress is set.
//
// Any other Accept value yields 406 Not Acceptable.
func (s *server) handleHTTP(ctx context.Context, rw http.ResponseWriter, r *http.Request) error {
	isGRPC := r.ProtoMajor == 2 && strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc")
	if isGRPC {
		s.grpcServer.ServeHTTP(rw, r)
		return nil
	}

	accept := r.Header.Get("Accept")
	switch {
	case accept == "application/json":
		return s.handleJSON(ctx, rw, r)
	case accept == "text/event-stream":
		return s.handleEventStream(ctx, rw, r)
	case accept == "text/plain" && *allowStress:
		return s.stress(ctx, rw, r)
	}

	return httpError{
		httpStatusCode: http.StatusNotAcceptable,
		publicError:    "Accept should be text/event-stream or application/json",
	}
}

// ServeHTTP implements http.Handler. It runs handleHTTP through a
// watchingResponseWriter and logs the outcome. On error, and only if nothing
// has been written to the response yet, it replies with a text/plain error.
// The status code and message come from an httpError found in the error
// chain; the default is 500 with the error's text.
func (s *server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	// Checked assertion: a ResponseWriter that can't flush (e.g. one wrapped
	// by middleware) used to panic here. A nil Flusher is tolerated by
	// watchingResponseWriter.Flush.
	flusher, _ := rw.(http.Flusher)
	wrw := &watchingResponseWriter{ResponseWriter: rw, Flusher: flusher}

	err := s.handleHTTP(r.Context(), wrw, r)
	if err == nil {
		log.Printf("[%v] %v %v -- (%d) %d bytes", r.RemoteAddr, r.Method, r.RequestURI, wrw.WroteStatus, wrw.WroteContentBytes)
		return
	}

	statusCode := http.StatusInternalServerError
	errText := fmt.Sprintf("internal server error: %v", err.Error())

	var he httpError
	if errors.As(err, &he) {
		if he.httpStatusCode != 0 {
			statusCode = he.httpStatusCode
		}
		if he.publicError != "" {
			errText = he.publicError
		}
	}
	log.Printf("[%v] %v %v -- %v (%d %v)", r.RemoteAddr, r.Method, r.RequestURI, err, statusCode, errText)

	// Once a status or body has been sent, the error can only be logged.
	if wrw.WrittenTo {
		return
	}

	// Drop headers the handler prepared for a successful response (such as
	// Content-Type). Use delete rather than Header.Del: Del canonicalizes its
	// argument and would miss keys stored in non-canonical form.
	h := wrw.Header()
	for k := range h {
		delete(h, k)
	}
	h.Set("Content-Type", "text/plain")
	wrw.WriteHeader(statusCode)
	// Best effort: the client may be gone, and the error was already logged.
	fmt.Fprintf(wrw, "%s\n", errText)
}

func main() {
	flag.Parse()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cfg, err := pgxpool.ParseConfig("host=/var/run/postgresql database=trains search_path=darwindb")
	if err != nil {
		log.Fatalf("pgxpool.ParseConfig: %v", err)
	}
	cfg.AfterRelease = func(pc *pgx.Conn) bool {
		// true to return to pool, false to destroy
		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
		defer cancel()

		_, err := pc.Exec(ctx, "UNLISTEN *")
		if err != nil {
			log.Printf("pg connection cleanup: UNLISTEN failed: %v", err)
			return false
		}

		return true
	}

	pcpool, err := pgxpool.ConnectConfig(ctx, cfg)
	if err != nil {
		log.Fatalf("pgxpool.ConnectConfig: %v", err)
	}
	defer pcpool.Close()

	gs := grpc.NewServer()
	webapi.RegisterTrainsServiceServer(gs, &grpcServer{
		dbPool: pcpool,
	})

	s := &server{
		dbPool:     pcpool,
		grpcServer: gs,
	}
	listen := ":13974"
	httpServer := &http.Server{
		Addr:    listen,
		Handler: s,
	}
	if *certFile != "" && *keyFile != "" {
		log.Printf("listening on %s (TLS)", listen)
		log.Fatal(httpServer.ListenAndServeTLS(*certFile, *keyFile))
	} else {
		log.Printf("listening on %s (plaintext HTTP)", listen)
		log.Fatal(httpServer.ListenAndServe())
	}
}