package darwingest

import (
	"bytes"
	"compress/gzip"
	"context"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"log"
	"strings"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/go-stomp/stomp/v3"
	"github.com/jlaffaye/ftp"
	"gocloud.dev/blob"
	"hg.lukegb.com/lukegb/depot/go/trains/darwin"
	"hg.lukegb.com/lukegb/depot/go/trains/darwin/darwingest/darwingestftp"
	"hg.lukegb.com/lukegb/depot/go/trains/darwin/darwingest/darwingests3"
)

// degzip returns the fully decompressed contents of the gzip stream b.
//
// Reading all the way to io.EOF is what makes the gzip reader verify the
// trailing checksum, so a corrupt stream surfaces as an error here.
func degzip(b []byte) ([]byte, error) {
	gzr, err := gzip.NewReader(bytes.NewReader(b))
	if err != nil {
		return nil, fmt.Errorf("gzip.NewReader: %w", err)
	}
	// The gzip package leaves closing the Reader to the caller.
	defer gzr.Close()

	d, err := io.ReadAll(gzr)
	if err != nil {
		return nil, fmt.Errorf("io.ReadAll: %w", err)
	}
	return d, nil
}

// unmarshalGzipPushPort gunzips m (for example a STOMP message body) and
// decodes the resulting XML document into a freshly allocated darwin.PushPort.
// On an XML error, the decompressed document is quoted in the returned error.
func unmarshalGzipPushPort(m []byte) (*darwin.PushPort, error) {
	raw, err := degzip(m)
	if err != nil {
		return nil, fmt.Errorf("degzip: %w", err)
	}

	pp := new(darwin.PushPort)
	if err = xml.Unmarshal(raw, pp); err != nil {
		return nil, fmt.Errorf("xml.Unmarshal(%q): %w", string(raw), err)
	}
	return pp, nil
}

// ridToDate derives the date used to select the initial timetable from rid.
// It parses the first eight characters as a YYYYMMDD date in darwin.London.
// It returns an error if rid is too short or those characters are not a
// valid date.
//
// NOTE(review): a "20060102" layout always parses to midnight, so the
// before-02:45 check below is always true. The result is therefore always
// the day before the RID's date. This preserves the existing behaviour;
// confirm whether the cut-off was meant to apply to the current time instead.
func ridToDate(rid string) (time.Time, error) {
	// Guard the slice below: a short or malformed RID must produce an error,
	// not a panic.
	if len(rid) < 8 {
		return time.Time{}, fmt.Errorf("RID %q is too short to contain a YYYYMMDD date", rid)
	}
	t, err := time.ParseInLocation("20060102", rid[:8], darwin.London)
	if err != nil {
		return time.Time{}, err
	}
	if t.Hour() < 2 || (t.Hour() == 2 && t.Minute() < 45) {
		t = t.AddDate(0, 0, -1)
	}
	return t, nil
}

// Config supplies the connection factories and data callbacks used by Start
// and CatchUpAndStart.
type Config struct {
	// NOTE(review): Verbose is not referenced by run in this file; confirm
	// whether anything still reads it.
	Verbose bool

	// DialStomp opens a STOMP connection. It is called at startup, and again
	// during reconnection if subscribing on the existing connection fails.
	DialStomp       func(ctx context.Context) (*stomp.Conn, error)
	// StompSubscribe subscribes on sc. It is called at startup and on every
	// reconnection attempt. Message bodies received on the subscription are
	// decoded as gzip-compressed push-port XML.
	StompSubscribe  func(ctx context.Context, sc *stomp.Conn) (*stomp.Subscription, error)
	// DialFTP opens the FTP connection used to fetch the snapshot and
	// push-port logs. It is only called when catching up (CatchUpAndStart).
	DialFTP         func(ctx context.Context) (*ftp.ServerConn, error)
	// Message handles live push-port messages. Returning nil causes the STOMP
	// message to be ACKed; returning an error causes it to be NACKed.
	Message         func(pp *darwin.PushPort) error
	// MessageCatchUp handles messages replayed from the FTP snapshot and logs
	// during catch-up.
	MessageCatchUp  func(pp *darwin.PushPort) error
	// Timetable receives timetables: the initial one loaded during catch-up,
	// and any announced in a message's TimeTableID elements.
	Timetable       func(ppt *darwin.PushPortTimetable) error
	// ReferenceData receives reference data: the latest set loaded at the
	// start of CatchUpAndStart, and any announced in a message's TimeTableID
	// elements.
	ReferenceData   func(pprd *darwin.PushPortReferenceData) error
	// CatchUpComplete is called once, before live messages are handled. With
	// catch-up it runs after the snapshot and logs are replayed; with Start it
	// runs immediately.
	CatchUpComplete func(ctx context.Context) error
}

// Start consumes live push-port messages from STOMP without replaying any
// history first. cfg.CatchUpComplete is invoked straight away, before the
// first live message is handled. Start blocks until ctx is done or an
// unrecoverable error occurs, and always returns a non-nil error.
func Start(ctx context.Context, bucket *blob.Bucket, cfg Config) error {
	const catchUp = false
	return run(ctx, bucket, cfg, catchUp)
}

// CatchUpAndStart is like Start, but first catches up:
//   - it loads the latest reference data from bucket;
//   - it waits for the first live message to learn the replay end time;
//   - it loads the matching timetable;
//   - it replays the FTP snapshot and push-port logs through cfg.MessageCatchUp;
//   - it calls cfg.CatchUpComplete.
//
// Only then does it hand live messages to cfg.Message.
func CatchUpAndStart(ctx context.Context, bucket *blob.Bucket, cfg Config) error {
	const catchUp = true
	return run(ctx, bucket, cfg, catchUp)
}

// run implements Start (runCatchup=false) and CatchUpAndStart (runCatchup=true).
//
// It subscribes to STOMP via cfg and, if requested, performs the catch-up
// sequence. It then loops handling live messages. Each message is gunzipped
// and decoded. Any timetable or reference files it announces are loaded first;
// the message is then passed to cfg.Message. It is ACKed on success and NACKed
// on failure.
//
// When the subscription closes, run resubscribes, redialling the connection if
// needed, with exponential backoff. Reconnection gives up once 24h elapse
// without a successfully handled message.
//
// run returns ctx.Err() when ctx is done, or an error if setup, catch-up or
// reconnection fails.
func run(ctx context.Context, bucket *blob.Bucket, cfg Config, runCatchup bool) error {
	// handleMsg wraps cb so that timetable and reference data files announced
	// in a message's TimeTableID elements are delivered to cfg before cb sees
	// the message itself.
	handleMsg := func(cb func(*darwin.PushPort) error) func(*darwin.PushPort) error {
		return func(pp *darwin.PushPort) error {
			for _, ttid := range pp.TimeTableID {
				// The two files are handled independently. A timetable with an
				// unsupported version must not stop the reference data announced
				// alongside it from being loaded.
				if err := loadAnnouncedTimetable(ctx, bucket, ttid.Filename, cfg.Timetable); err != nil {
					return err
				}
				if err := loadAnnouncedReferenceData(ctx, bucket, ttid.ReferenceFilename, cfg.ReferenceData); err != nil {
					return err
				}
			}
			return cb(pp)
		}
	}

	if runCatchup {
		// Load the latest reference data.
		rd, err := darwingests3.LoadLatestReferenceData(ctx, bucket)
		if err != nil {
			return fmt.Errorf("LoadLatestReferenceData: %w", err)
		}
		if err := cfg.ReferenceData(rd); err != nil {
			return fmt.Errorf("cfg.ReferenceData: %w", err)
		}
	}

	stompBackoff := backoff.NewExponentialBackOff()
	stompBackoff.MaxElapsedTime = 24 * time.Hour

	stompConn, err := cfg.DialStomp(ctx)
	if err != nil {
		return fmt.Errorf("DialStomp: %w", err)
	}
	var sub *stomp.Subscription
	// Clean up whichever connection and subscription are current when we
	// return. Reconnection replaces both, so deferring inside the loop would
	// pile up one deferred call per reconnect for the life of the process.
	defer func() {
		if sub != nil {
			_ = sub.Unsubscribe() // best effort: it may already be closed
		}
		if stompConn != nil {
			_ = stompConn.Disconnect()
		}
	}()

	sub, err = cfg.StompSubscribe(ctx, stompConn)
	if err != nil {
		return fmt.Errorf("StompSubscribe: %w", err)
	}

	// drainConnection unsubscribes and NACKs everything still buffered on the
	// subscription, so that undelivered messages go back to the broker.
	drainConnection := func() {
		if sub == nil {
			return
		}
		go sub.Unsubscribe()
		for m := range sub.C {
			_ = stompConn.Nack(m)
		}
	}

	if runCatchup {
		if err := func() error {
			log.Println("awaiting first message to find replay timestamp")

			// Find the start timestamp we need to replay up to.
			var m *stomp.Message
			select {
			case <-ctx.Done():
				return ctx.Err()
			case first, ok := <-sub.C:
				if !ok {
					return errors.New("subscription closed before the first message arrived")
				}
				m = first
			}
			if m.Err != nil {
				return fmt.Errorf("first message: %w", m.Err)
			}
			// NACK the first message on every failure path below, but never
			// after it has been ACKed successfully.
			acked := false
			defer func() {
				if !acked {
					_ = stompConn.Nack(m)
				}
			}()

			d, err := unmarshalGzipPushPort(m.Body)
			if err != nil {
				return fmt.Errorf("unmarshalGzipPushPort: %w", err)
			}
			log.Printf("first message is at %v", d.Time)
			// Don't handle the message just yet.

			rids := d.AffectedRIDs()
			var timetableDate time.Time
			if len(rids) == 0 {
				log.Printf("no affected RIDs in first message! falling back to time.Now...")
				timetableDate = time.Now()
			} else {
				timetableDate, err = ridToDate(rids[0])
				if err != nil {
					log.Printf("couldn't parse RID %q to date, falling back to time.Now: %v", rids[0], err)
					timetableDate = time.Now()
				}
			}

			// Load the timetable first.
			log.Println("loading initial timetable")
			ppt, err := darwingests3.LoadTimetableForDate(ctx, bucket, timetableDate)
			if err != nil {
				return fmt.Errorf("LoadTimetableForDate: %w", err)
			}

			log.Println("feeding timetable to database")
			if err := cfg.Timetable(ppt); err != nil {
				return fmt.Errorf("Timetable: %w", err)
			}

			log.Println("connecting to FTP")
			sc, err := cfg.DialFTP(ctx)
			if err != nil {
				return fmt.Errorf("DialFTP: %w", err)
			}
			defer sc.Quit()

			// Grab the full snapshot dump.
			log.Println("loading snapshot")
			lw, _, err := darwingestftp.LoadSnapshot(ctx, sc, handleMsg(cfg.MessageCatchUp))
			if err != nil {
				return fmt.Errorf("LoadSnapshot: %w", err)
			}

			// Now "catch it up" using the push-port logs, up to the first live
			// message's timestamp.
			log.Println("loading logs")
			if err := darwingestftp.LoadLogs(ctx, sc, lw, d.Time, handleMsg(cfg.MessageCatchUp)); err != nil {
				return fmt.Errorf("LoadLogs: %w", err)
			}

			log.Println("marking catchup complete")
			if err := cfg.CatchUpComplete(ctx); err != nil {
				return fmt.Errorf("CatchUpComplete callback: %w", err)
			}

			log.Println("we think we've caught up! let's go...")
			if err := handleMsg(cfg.Message)(d); err != nil {
				return fmt.Errorf("replaying first pushport message: %w", err)
			}
			if err := stompConn.Ack(m); err != nil {
				return fmt.Errorf("acking first pushport message: %w", err)
			}
			acked = true
			return nil
		}(); err != nil {
			drainConnection()
			return err
		}
	} else {
		if err := cfg.CatchUpComplete(ctx); err != nil {
			return fmt.Errorf("CatchUpComplete callback: %w", err)
		}
	}

	mustReconnect := false
	for {
		if sub == nil {
			mustReconnect = true
		}
		if mustReconnect {
			duration := stompBackoff.NextBackOff()
			if duration == backoff.Stop {
				drainConnection()
				return fmt.Errorf("exceeded stomp reconnection retries")
			}
			log.Printf("subscription appears to have closed, retrying after %v...", duration)
			if err := darwin.Sleep(ctx, duration); err != nil {
				drainConnection()
				return err
			}

			var err error
			log.Printf("resubscribing...")
			if stompConn != nil {
				sub, err = cfg.StompSubscribe(ctx, stompConn)
			} else {
				err = fmt.Errorf("no stompConn")
			}
			if err != nil {
				log.Printf("StompSubscribe while resubscribing: %v", err)
				// Force close the server connection and dial again.
				if stompConn != nil {
					_ = stompConn.MustDisconnect()
				}
				stompConn, err = cfg.DialStomp(ctx)
				if err != nil {
					log.Printf("DialStomp failed as well: %v", err)
					// Try again later.
					continue
				}

				sub, err = cfg.StompSubscribe(ctx, stompConn)
				if err != nil {
					log.Printf("StompSubscribe with new connection failed: %v", err)
					// Try again later.
					continue
				}
			}
			mustReconnect = false
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case m, ok := <-sub.C:
			if !ok {
				mustReconnect = true
				continue
			}
			if m.Err != nil {
				// An ERROR frame or a lost connection, not a push-port payload:
				// there is nothing to NACK. The closed channel triggers a
				// reconnect on a later iteration.
				log.Printf("subscription delivered error: %v", m.Err)
				continue
			}
			d, err := unmarshalGzipPushPort(m.Body)
			if err != nil {
				log.Printf("unmarshalGzipPushPort: %v", err)
				if err := stompConn.Nack(m); err != nil {
					log.Printf("attempting to nack unparsable message: %v", err)
				}
				continue
			}

			if err := handleMsg(cfg.Message)(d); err != nil {
				log.Printf("Message: %v", err)
				if err := stompConn.Nack(m); err != nil {
					log.Printf("attempting to nack unhandled message: %v", err)
				}
				continue
			}
			if err := stompConn.Ack(m); err != nil {
				log.Printf("attempting to ack handled message: %v", err)
				continue
			}

			stompBackoff.Reset()
		}
	}
}

// loadAnnouncedTimetable loads the timetable file fn (whitespace-trimmed) from
// bucket and passes it to sink. An empty name is a no-op. A file rejected with
// darwingests3.ErrBadVersion is skipped without error.
func loadAnnouncedTimetable(ctx context.Context, bucket *blob.Bucket, fn string, sink func(*darwin.PushPortTimetable) error) error {
	fn = strings.TrimSpace(fn)
	if fn == "" {
		return nil
	}
	log.Printf("informed that there's a new timetable at %q", fn)
	ppt, err := darwingests3.LoadTimetable(ctx, bucket, fn)
	if errors.Is(err, darwingests3.ErrBadVersion) {
		return nil
	}
	if err != nil {
		return fmt.Errorf("LoadTimetable(%q): %w", fn, err)
	}
	if err := sink(ppt); err != nil {
		return fmt.Errorf("Timetable(%q): %w", fn, err)
	}
	return nil
}

// loadAnnouncedReferenceData loads the reference data file fn
// (whitespace-trimmed) from bucket and passes it to sink. An empty name is a
// no-op. A file rejected with darwingests3.ErrBadVersion is skipped without
// error.
func loadAnnouncedReferenceData(ctx context.Context, bucket *blob.Bucket, fn string, sink func(*darwin.PushPortReferenceData) error) error {
	fn = strings.TrimSpace(fn)
	if fn == "" {
		return nil
	}
	log.Printf("informed that there's a new refsheet at %q", fn)
	rd, err := darwingests3.LoadReferenceData(ctx, bucket, fn)
	if errors.Is(err, darwingests3.ErrBadVersion) {
		return nil
	}
	if err != nil {
		return fmt.Errorf("LoadReferenceData(%q): %w", fn, err)
	}
	if err := sink(rd); err != nil {
		return fmt.Errorf("cfg.ReferenceData(%q): %w", fn, err)
	}
	return nil
}