package summarize

import (
	"context"
	"fmt"
	"strings"

	"github.com/jackc/pgx/v4"
	"hg.lukegb.com/lukegb/depot/go/trains/webapi"
)

// Querier is the minimal query surface this package needs from a database
// handle. Anything offering pgx-style Query and QueryRow methods (for example
// a pgx connection or transaction) can be passed to the functions here.
type Querier interface {
	// QueryRow executes query with the given positional arguments and
	// returns a single pgx.Row to be scanned.
	QueryRow(ctx context.Context, query string, args ...interface{}) pgx.Row

	// Query executes query with the given positional arguments and returns
	// the resulting rows, which the caller is responsible for closing.
	Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
}

// serviceWhere loads at most limit train services matching where, together
// with the schedule locations of each service.
//
// where is a SQL boolean expression spliced verbatim into the WHERE clause of
// the train_services query (aliased ts), so it must be trusted, program-built
// SQL. Caller-supplied values belong in params and should be referenced from
// where as $1, $2, ... placeholders. A blank where matches every service.
//
// Services are returned ordered by ID. Only locations with active_in_schedule
// set are attached, ordered by the first non-NULL of their scheduled working
// arrival, pass and departure times. If nothing matches, serviceWhere returns
// (nil, nil).
func serviceWhere(ctx context.Context, pc Querier, limit int, where string, params ...interface{}) ([]*webapi.ServiceData, error) {
	if strings.TrimSpace(where) == "" {
		// "WHERE" followed directly by "ORDER BY" is a syntax error; match
		// everything instead.
		where = "TRUE"
	}

	// The delay-reason and cancellation-reason columns have the same shape but
	// must read from their own joins (rlr/rlrloc/rlrop vs rc/rcloc/rcop) so
	// the two reasons are reported independently. where is parenthesised so
	// its operator precedence cannot interact with the surrounding SQL.
	rows, err := pc.Query(ctx, `
SELECT
	ts.id, ts.rid, ts.uid, ts.rsid, ts.headcode, ts.scheduled_start_date::varchar,
	ts.train_operator, COALESCE(rop.name, ''), COALESCE(rop.url, ''),
	COALESCE(ts.delay_reason_code::int, 0), COALESCE(rlr.text, ''), COALESCE(ts.delay_reason_tiploc, ''), COALESCE(rlrloc.name, ''), COALESCE(rlrloc.crs, ''), COALESCE(rlrloc.toc, ''), COALESCE(rlrop.name, ''), COALESCE(rlrop.url, ''), COALESCE(ts.delay_reason_tiploc_near, false),
	COALESCE(ts.cancellation_reason_code::int, 0), COALESCE(rc.text, ''), COALESCE(ts.cancellation_reason_tiploc, ''), COALESCE(rcloc.name, ''), COALESCE(rcloc.crs, ''), COALESCE(rcloc.toc, ''), COALESCE(rcop.name, ''), COALESCE(rcop.url, ''), COALESCE(ts.cancellation_reason_tiploc_near, false),
	ts.active, ts.deleted, ts.cancelled
FROM
	train_services ts
	LEFT JOIN ref_tocs rop ON rop.toc=ts.train_operator

	LEFT JOIN ref_late_running_reasons rlr ON rlr.code=ts.delay_reason_code::int
	LEFT JOIN ref_locations rlrloc ON rlrloc.tiploc=ts.delay_reason_tiploc
	LEFT JOIN ref_tocs rlrop ON rlrop.toc=rlrloc.toc

	LEFT JOIN ref_cancel_reasons rc ON rc.code=ts.cancellation_reason_code::int
	LEFT JOIN ref_locations rcloc ON rcloc.tiploc=ts.cancellation_reason_tiploc
	LEFT JOIN ref_tocs rcop ON rcop.toc=rcloc.toc
WHERE
	(`+where+`)`+fmt.Sprintf(`
ORDER BY ts.id
LIMIT %d`, limit), params...)
	if err != nil {
		return nil, fmt.Errorf("querying for services: %w", err)
	}
	defer rows.Close()
	var sds []*webapi.ServiceData
	sdMap := make(map[uint64]*webapi.ServiceData)

	for rows.Next() {
		// Pre-allocate every nested message so Scan has somewhere to write.
		sd := &webapi.ServiceData{
			DelayReason:  &webapi.DisruptionReason{Location: &webapi.Location{Operator: &webapi.TrainOperator{}}},
			CancelReason: &webapi.DisruptionReason{Location: &webapi.Location{Operator: &webapi.TrainOperator{}}},
			Operator:     &webapi.TrainOperator{},
		}
		if err := rows.Scan(
			&sd.Id, &sd.Rid, &sd.Uid, &sd.Rsid, &sd.Headcode, &sd.ScheduledStartDate,
			&sd.Operator.Code, &sd.Operator.Name, &sd.Operator.Url,
			&sd.DelayReason.Code, &sd.DelayReason.Text, &sd.DelayReason.Location.Tiploc, &sd.DelayReason.Location.Name, &sd.DelayReason.Location.Crs, &sd.DelayReason.Location.Operator.Code, &sd.DelayReason.Location.Operator.Name, &sd.DelayReason.Location.Operator.Url, &sd.DelayReason.NearLocation,
			&sd.CancelReason.Code, &sd.CancelReason.Text, &sd.CancelReason.Location.Tiploc, &sd.CancelReason.Location.Name, &sd.CancelReason.Location.Crs, &sd.CancelReason.Location.Operator.Code, &sd.CancelReason.Location.Operator.Name, &sd.CancelReason.Location.Operator.Url, &sd.CancelReason.NearLocation,
			&sd.IsActive, &sd.IsDeleted, &sd.IsCancelled,
		); err != nil {
			return nil, fmt.Errorf("reading from train_services: %w", err)
		}
		sd.Canonicalize()
		sds = append(sds, sd)
		sdMap[sd.Id] = sd
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("reading rows for services: %w", err)
	}
	// Release the service result set before issuing the next query on pc.
	rows.Close()

	if len(sds) == 0 {
		return nil, nil
	}

	// The IDs are integers we just read back from the database, so formatting
	// them into the IN list cannot inject SQL. Building the list from the
	// ordered slice, not the map, keeps the query text deterministic.
	idStrs := make([]string, 0, len(sds))
	for _, sd := range sds {
		idStrs = append(idStrs, fmt.Sprintf("%d", sd.Id))
	}

	// Now for the locations.
	rows, err = pc.Query(ctx, fmt.Sprintf(`
SELECT
	tl.tsid, tl.id,
	tl.tiploc, COALESCE(rloc.name, ''), COALESCE(rloc.crs, ''), COALESCE(rloc.toc, ''), COALESCE(rloctoc.name, ''), COALESCE(rloctoc.url, ''),
	tl.calling_point::varchar, COALESCE(tl.train_length, 0), tl.service_suppressed,
	tl.schedule_cancelled,
	COALESCE(tl.schedule_platform, ''), COALESCE(tl.platform, ''), COALESCE(tl.platform_confirmed, false), COALESCE(tl.platform_suppressed, false),
	tl.schedule_false_destination_tiploc, COALESCE(rfdloc.name, ''), COALESCE(rfdloc.crs, ''), COALESCE(rfdloc.toc, ''), COALESCE(rfdloctoc.name, ''), COALESCE(rfdloctoc.url, ''),

	tl.schedule_public_arrival, tl.schedule_working_arrival, tl.estimated_arrival, tl.working_estimated_arrival, tl.actual_arrival,
	tl.schedule_public_departure, tl.schedule_working_departure, tl.estimated_departure, tl.working_estimated_departure, tl.actual_departure,
	/* no public_pass */ tl.schedule_working_pass, tl.estimated_pass, tl.working_estimated_pass, tl.actual_pass
FROM
	train_locations tl
	LEFT JOIN ref_locations rloc ON rloc.tiploc=tl.tiploc
	LEFT JOIN ref_tocs rloctoc ON rloctoc.toc=rloc.toc

	LEFT JOIN ref_locations rfdloc ON rfdloc.tiploc=tl.schedule_false_destination_tiploc
	LEFT JOIN ref_tocs rfdloctoc ON rfdloctoc.toc=rfdloc.toc
WHERE
	tl.tsid IN (%s) AND tl.active_in_schedule
ORDER BY
	tl.tsid, COALESCE(tl.schedule_working_arrival, tl.schedule_working_pass, tl.schedule_working_departure)
`, strings.Join(idStrs, ",")))
	if err != nil {
		return nil, fmt.Errorf("querying locations: %w", err)
	}
	defer rows.Close()
	for rows.Next() {
		var tsid uint64
		loc := webapi.ServiceLocation{
			Location:         &webapi.Location{Operator: &webapi.TrainOperator{}},
			Platform:         &webapi.PlatformData{},
			FalseDestination: &webapi.Location{Operator: &webapi.TrainOperator{}},
			ArrivalTiming:    &webapi.TimingData{PublicScheduled: &webapi.DateTime{}, WorkingScheduled: &webapi.DateTime{}, PublicEstimated: &webapi.DateTime{}, WorkingEstimated: &webapi.DateTime{}, Actual: &webapi.DateTime{}},
			DepartureTiming:  &webapi.TimingData{PublicScheduled: &webapi.DateTime{}, WorkingScheduled: &webapi.DateTime{}, PublicEstimated: &webapi.DateTime{}, WorkingEstimated: &webapi.DateTime{}, Actual: &webapi.DateTime{}},
			PassTiming:       &webapi.TimingData{PublicScheduled: &webapi.DateTime{}, WorkingScheduled: &webapi.DateTime{}, PublicEstimated: &webapi.DateTime{}, WorkingEstimated: &webapi.DateTime{}, Actual: &webapi.DateTime{}},
		}
		if err := rows.Scan(
			&tsid, &loc.Id,
			&loc.Location.Tiploc, &loc.Location.Name, &loc.Location.Crs, &loc.Location.Operator.Code, &loc.Location.Operator.Name, &loc.Location.Operator.Url,
			&loc.CallingPointType, &loc.TrainLength, &loc.ServiceSuppressed,
			&loc.Cancelled,
			&loc.Platform.Scheduled, &loc.Platform.Live, &loc.Platform.Confirmed, &loc.Platform.Suppressed,
			&loc.FalseDestination.Tiploc, &loc.FalseDestination.Name, &loc.FalseDestination.Crs, &loc.FalseDestination.Operator.Code, &loc.FalseDestination.Operator.Name, &loc.FalseDestination.Operator.Url,
			loc.ArrivalTiming.PublicScheduled, loc.ArrivalTiming.WorkingScheduled, loc.ArrivalTiming.PublicEstimated, loc.ArrivalTiming.WorkingEstimated, loc.ArrivalTiming.Actual,
			loc.DepartureTiming.PublicScheduled, loc.DepartureTiming.WorkingScheduled, loc.DepartureTiming.PublicEstimated, loc.DepartureTiming.WorkingEstimated, loc.DepartureTiming.Actual,
			// The query has no public scheduled pass column, so
			// PassTiming.PublicScheduled is left at its zero value.
			loc.PassTiming.WorkingScheduled, loc.PassTiming.PublicEstimated, loc.PassTiming.WorkingEstimated, loc.PassTiming.Actual,
		); err != nil {
			return nil, fmt.Errorf("scanning locations: %w", err)
		}
		loc.Canonicalize()
		sd, ok := sdMap[tsid]
		if !ok {
			// Should be impossible given the IN filter, but fail loudly
			// rather than dereference a nil service.
			return nil, fmt.Errorf("scanning locations: got location for unexpected service %d", tsid)
		}
		sd.Locations = append(sd.Locations, &loc)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating over locations: %w", err)
	}

	return sds, nil
}

// Service returns the train service whose ID is id, including its schedule
// locations. It returns an error if the query fails or if it does not match
// exactly one service.
func Service(ctx context.Context, pc Querier, id int) (*webapi.ServiceData, error) {
	// Ask for up to two rows so an unexpected duplicate is reported rather
	// than silently ignored.
	sds, err := serviceWhere(ctx, pc, 2, "ts.id=$1", id)
	if err != nil {
		return nil, fmt.Errorf("query for service %d: %w", id, err)
	}

	if len(sds) != 1 {
		// err is nil here, so there is nothing to wrap.
		return nil, fmt.Errorf("query for service %d returned %d services", id, len(sds))
	}
	return sds[0], nil
}

// ListServices returns the train services matching every filter set in req,
// ordered by service ID.
//
// Empty string filters are ignored. Each location filter (origin,
// destination, calling_at, passing) must name a tiploc and/or CRS. It keeps
// only services that have a train_locations row matching that tiploc/CRS
// whose calling_point is one of the codes listed for that filter. Unless the
// filter sets IncludeInPast, that row's actual/estimated departure or pass
// time must also be after CURRENT_TIMESTAMP. Deleted services are excluded
// unless IncludeDeleted is set. MaxResults defaults to, and is capped at, 100.
func ListServices(ctx context.Context, pc Querier, req *webapi.ListServicesRequest) ([]*webapi.ServiceData, error) {
	var wheres []string
	var vars []interface{}
	// addVar records x as the next positional parameter and returns its
	// placeholder, so request values never get concatenated into the SQL.
	addVar := func(x interface{}) string {
		vars = append(vars, x)
		return fmt.Sprintf("$%d", len(vars))
	}
	// addVars is addVar for several values, returning a comma-separated list.
	addVars := func(xs ...interface{}) string {
		xvars := make([]string, len(xs))
		for n, x := range xs {
			xvars[n] = addVar(x)
		}
		return strings.Join(xvars, ",")
	}

	if req.GetRid() != "" {
		wheres = append(wheres, fmt.Sprintf("ts.rid=%s", addVar(req.GetRid())))
	}
	if req.GetUid() != "" {
		wheres = append(wheres, fmt.Sprintf("ts.uid=%s", addVar(req.GetUid())))
	}
	if req.GetHeadcode() != "" {
		wheres = append(wheres, fmt.Sprintf("ts.headcode=%s", addVar(req.GetHeadcode())))
	}
	if req.GetScheduledStartDate() != "" {
		wheres = append(wheres, fmt.Sprintf("ts.scheduled_start_date::varchar=%s", addVar(req.GetScheduledStartDate())))
	}
	if req.GetTrainOperatorCode() != "" {
		wheres = append(wheres, fmt.Sprintf("ts.train_operator=%s", addVar(req.GetTrainOperatorCode())))
	}

	// addLocation adds a subquery restricting services to those with a
	// matching location whose calling_point is one of passingPointType.
	// A nil lr adds nothing.
	addLocation := func(lr *webapi.LocationRequest, passingPointType ...interface{}) error {
		if lr == nil {
			return nil
		}
		var subwheres []string
		var specified bool
		if lr.GetTiploc() != "" {
			subwheres = append(subwheres, fmt.Sprintf("tl.tiploc=%s", addVar(lr.GetTiploc())))
			specified = true
		}
		if lr.GetCrs() != "" {
			subwheres = append(subwheres, fmt.Sprintf("rl.crs=%s AND rl.crs IS NOT NULL", addVar(lr.GetCrs())))
			specified = true
		}
		if !specified {
			return fmt.Errorf("must specify either tiploc or crs")
		}
		if !lr.GetIncludeInPast() {
			subwheres = append(subwheres, "COALESCE(tl.actual_departure, tl.actual_pass, tl.estimated_departure, tl.estimated_pass, tl.working_estimated_departure, tl.working_estimated_pass) > CURRENT_TIMESTAMP")
		}
		wheres = append(wheres, fmt.Sprintf("ts.id IN (SELECT tl.tsid FROM train_locations tl LEFT JOIN ref_locations rl ON rl.tiploc=tl.tiploc WHERE tl.calling_point IN (%s) AND %s)", addVars(passingPointType...), strings.Join(subwheres, " AND ")))
		return nil
	}
	if err := addLocation(req.GetOrigin(), "OR"); err != nil {
		return nil, fmt.Errorf("origin: %w", err)
	}
	if err := addLocation(req.GetDestination(), "DT"); err != nil {
		return nil, fmt.Errorf("destination: %w", err)
	}
	if err := addLocation(req.GetCallingAt(), "OR", "DT", "IP"); err != nil {
		return nil, fmt.Errorf("calling_at: %w", err)
	}
	if err := addLocation(req.GetPassing(), "OR", "DT", "IP", "PP", "OPIP", "OPOR", "OPDT"); err != nil {
		return nil, fmt.Errorf("passing: %w", err)
	}

	if !req.GetIncludeDeleted() {
		wheres = append(wheres, "NOT ts.deleted")
	}
	limit := req.GetMaxResults()
	if limit == 0 || limit > 100 {
		limit = 100
	}

	// With no filters at all (e.g. only IncludeDeleted set), joining an empty
	// slice would leave a bare "WHERE", which is invalid SQL.
	where := "TRUE"
	if len(wheres) > 0 {
		where = strings.Join(wheres, " AND ")
	}
	return serviceWhere(ctx, pc, int(limit), where, vars...)
}