package darwindb

import (
	"context"
	"errors"
	"fmt"
	"log"
	"strings"
	"time"

	pgx "github.com/jackc/pgx/v4"
	"hg.lukegb.com/lukegb/depot/go/trains/darwin"
)

// handleTrainStatus handles a Darwin "TS" (train status) message.
//
// It first resolves the train_services row for ts.RID. When ts.LateReason is
// set, the delay reason columns are updated in the same statement. It then
// updates, for each location in the message, the matching active
// train_locations row. Darwin only sends fields that have changed, so each
// location's UPDATE is assembled from whichever parts of the location are
// present.
//
// A message for an RID with no train_services row is logged and ignored (nil
// is returned). The RID, the train_services ID and every TIPLOC that was
// updated are recorded on a.
//
// An error is returned if any query fails, if a location carries no scheduled
// times to identify its row by, or if a location's UPDATE matches more than
// one row.
func handleTrainStatus(ctx context.Context, tx pgx.Tx, ts *darwin.TrainStatus, a *Affected) error {
	var tsid int
	var tsEST time.Time
	if ts.LateReason != nil {
		row := tx.QueryRow(ctx, `
UPDATE train_services
SET delay_reason_code=$2, delay_reason_tiploc=$3, delay_reason_tiploc_near=$4
WHERE rid=$1
RETURNING id, effective_start_time
`, ts.RID, ts.LateReason.Code, ts.LateReason.TIPLOC, ts.LateReason.Near)
		if err := row.Scan(&tsid, &tsEST); errors.Is(err, pgx.ErrNoRows) {
			log.Printf("got TrainStatus (delayed) for RID %v which isn't known", ts.RID)
			return nil
		} else if err != nil {
			return fmt.Errorf("updating delay reason: %w", err)
		}
	} else {
		row := tx.QueryRow(ctx, "SELECT id, effective_start_time FROM train_services WHERE rid=$1", ts.RID)
		if err := row.Scan(&tsid, &tsEST); errors.Is(err, pgx.ErrNoRows) {
			log.Printf("got TrainStatus for RID %v which isn't known", ts.RID)
			return nil
		} else if err != nil {
			return fmt.Errorf("fetching train_services ID: %w", err)
		}
	}
	a.RID(ts.RID)
	a.TSID(tsid)

	// Darwin platform-source codes mapped to the value written to
	// platform_source. An empty code is treated as "planned".
	// NOTE(review): any other code maps to "" — confirm the column accepts that.
	platformSources := map[string]string{
		"P": "planned",
		"A": "automatic",
		"M": "manual",
		"":  "planned",
	}

	for _, l := range ts.Location {
		// Only changed fields are sent, so the SET clause is built
		// dynamically. els holds the positional query arguments. addSet
		// appends one argument and returns "column=$N" with its placeholder.
		var els []interface{}
		addSet := func(column string, x interface{}) string {
			els = append(els, x)
			return fmt.Sprintf("%s=$%d", column, len(els))
		}
		// addTSTimeData returns SET fragments for one of the arrival,
		// departure or pass time blocks, or nil if the block is absent. The
		// composite literal's calls run left to right, so the placeholders are
		// numbered in the order listed.
		addTSTimeData := func(prefix string, td *darwin.TSTimeData) []string {
			if td == nil {
				return nil
			}
			return []string{
				addSet("estimated_"+prefix, timeToTimestamp(td.ET, tsEST)),
				addSet("working_estimated_"+prefix, timeToTimestamp(td.WET, tsEST)),
				addSet("actual_"+prefix, timeToTimestamp(td.AT, tsEST)),
				addSet(prefix+"_data", td),
			}
		}

		var setClause []string
		setClause = append(setClause, addTSTimeData("arrival", l.Arr)...)
		setClause = append(setClause, addTSTimeData("departure", l.Dep)...)
		setClause = append(setClause, addTSTimeData("pass", l.Pass)...)
		if l.Plat != nil {
			setClause = append(setClause,
				addSet("platform", l.Plat.Platform),
				addSet("platform_suppressed", l.Plat.PlatSup),
				addSet("cis_platform_suppression", l.Plat.CisPlatSup),
				addSet("platform_source", platformSources[l.Plat.PlatSrc]),
				addSet("platform_confirmed", l.Plat.Conf),
			)
		}
		if l.Suppr != nil {
			setClause = append(setClause, addSet("service_suppressed", *l.Suppr))
		}
		if l.Length != nil {
			setClause = append(setClause, addSet("train_length", *l.Length))
		}
		if len(setClause) == 0 {
			// Nothing to set for this location.
			continue
		}

		// The location row is identified by its scheduled times (plus TIPLOC
		// and service ID). Iterate in a fixed order so the same set of fields
		// always yields the same SQL text. Ranging over a map would shuffle the
		// conditions between calls, which defeats SQL-text-keyed statement
		// caching and makes logged queries hard to compare.
		var whereClause []string
		for _, sched := range []struct {
			column string
			value  string
		}{
			{"schedule_public_arrival", l.PTA},
			{"schedule_working_arrival", l.WTA},
			{"schedule_public_departure", l.PTD},
			{"schedule_working_departure", l.WTD},
			{"schedule_working_pass", l.WTP},
		} {
			if sched.value == "" {
				continue
			}
			whereClause = append(whereClause, addSet(sched.column, timeToTimestamp(sched.value, tsEST)))
		}
		if len(whereClause) == 0 {
			return fmt.Errorf("RID %v/my ID %v at TIPLOC %v: ended up with no CircularTimes-based WHERE", ts.RID, tsid, l.TIPLOC)
		}
		whereClause = append(whereClause, addSet("tiploc", l.TIPLOC), addSet("tsid", tsid))

		query := fmt.Sprintf("UPDATE train_locations SET %s WHERE active_in_schedule AND %s", strings.Join(setClause, ", "), strings.Join(whereClause, " AND "))
		cmd, err := tx.Exec(ctx, query, els...)
		if err != nil {
			// The returned error carries the full context, so it is not logged here as well.
			return fmt.Errorf("RID %v/my ID %v at TIPLOC %v: failed to update with query %q: %w", ts.RID, tsid, l.TIPLOC, query, err)
		}
		a.TIPLOC(l.TIPLOC)
		switch n := cmd.RowsAffected(); {
		case n > 1:
			return fmt.Errorf("RID %v/my ID %v at TIPLOC %v: query %q: wanted to update 1 row, updated %d", ts.RID, tsid, l.TIPLOC, query, n)
		case n == 0:
			// No active row matched. This is logged but deliberately non-fatal.
			log.Printf("RID %v/my ID %v at TIPLOC %v: query %q: wanted to update 1 row, updated %d", ts.RID, tsid, l.TIPLOC, query, n)
		}
	}

	return nil
}