package main

import (
	"context"
	"flag"
	"fmt"
	"log"
	"time"

	firebase "firebase.google.com/go/v4"
	"firebase.google.com/go/v4/messaging"
	"github.com/jackc/pgconn"
	"github.com/jackc/pgx/v4"
)

// projectID selects the Firebase project whose messaging client is used.
var projectID = flag.String("project_id", "lukegb-trains", "Firebase project ID")

// Batching controls. Pending topics are flushed every batchInterval, or
// immediately once the number of distinct pending topics reaches batchMax.
var (
	batchInterval = flag.Duration("batch_interval", time.Second, "interval at which to batch messages")
	batchMax      = flag.Int("batch_max", 300, "Maximum messages before we will instantly send the batch")
)

// main subscribes to the Postgres "firehose" NOTIFY channel and, for every
// notification, sends an FCM message to the topic named by the notification
// payload. Repeated notifications for the same topic are coalesced. Pending
// topics are flushed every -batch_interval, or as soon as -batch_max distinct
// topics are pending.
func main() {
	flag.Parse()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	pc, err := pgx.Connect(ctx, "host=/var/run/postgresql database=trains search_path=darwindb")
	if err != nil {
		log.Fatalf("pgx.Connect: %v", err)
	}
	defer pc.Close(context.Background())

	fapp, err := firebase.NewApp(ctx, &firebase.Config{
		ProjectID: *projectID,
	})
	if err != nil {
		log.Fatalf("firebase.NewApp: %v", err)
	}

	fcm, err := fapp.Messaging(ctx)
	if err != nil {
		log.Fatalf("fapp.Messaging: %v", err)
	}

	if _, err := pc.Exec(ctx, "LISTEN firehose"); err != nil {
		log.Fatalf("subscribing to firehose: %v", err)
	}

	// Pump notifications into a channel so the main loop can select on them
	// alongside the batch ticker. main does not use pc after this point.
	notificationCh := make(chan *pgconn.Notification)
	go func() {
		defer close(notificationCh)
		for {
			n, err := pc.WaitForNotification(ctx)
			if err != nil {
				if ctx.Err() != nil {
					// The wait was cancelled because we are shutting down;
					// that is not a failure.
					return
				}
				// Any other error is fatal. Note that log.Fatalf exits without
				// running deferred cleanup.
				log.Fatalf("WaitForNotification: %v", err)
			}
			select {
			case notificationCh <- n:
			case <-ctx.Done():
				return
			}
		}
	}()

	t := time.NewTicker(*batchInterval)
	defer t.Stop()

	// FCM's SendAll rejects requests with more than 500 messages, so larger
	// batches (a -batch_max above 500, or a backlog left by failed sends) are
	// sent in chunks of at most this size.
	const maxMessagesPerSendAll = 500

	// batch is the set of topics with a pending notification. Using a set
	// coalesces repeated notifications for the same topic.
	batch := map[string]struct{}{}

	// sendBatchNow sends one message per pending topic. A topic is removed
	// from batch once its chunk has been accepted by SendAll. If SendAll
	// itself fails, the unsent topics stay in batch and are retried on the
	// next send.
	sendBatchNow := func(ctx context.Context) error {
		topics := make([]string, 0, len(batch))
		for topic := range batch {
			topics = append(topics, topic)
		}
		for len(topics) > 0 {
			chunk := topics
			if len(chunk) > maxMessagesPerSendAll {
				chunk = chunk[:maxMessagesPerSendAll]
			}
			topics = topics[len(chunk):]

			msgs := make([]*messaging.Message, len(chunk))
			for i, topic := range chunk {
				msgs[i] = &messaging.Message{Topic: topic}
			}
			br, err := fcm.SendAll(ctx, msgs)
			if err != nil {
				return fmt.Errorf("SendAll: %w", err)
			}
			log.Printf("sent batch: %d success, %d failure", br.SuccessCount, br.FailureCount)
			for i, sr := range br.Responses {
				if sr.Success {
					continue
				}
				// Responses are expected in the same order as msgs. A failed
				// response carries no MessageID, so identify it by its topic.
				topic := "<unknown>"
				if i < len(chunk) {
					topic = chunk[i]
				}
				log.Printf("message for topic %q: %v", topic, sr.Error)
			}
			// FCM accepted this chunk. Per-message failures were logged above
			// and are not retried.
			for _, topic := range chunk {
				delete(batch, topic)
			}
		}
		return nil
	}

	// maybeSendBatch flushes only once at least -batch_max distinct topics
	// are pending.
	maybeSendBatch := func(ctx context.Context) error {
		if len(batch) < *batchMax {
			return nil
		}
		return sendBatchNow(ctx)
	}

	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			if err := sendBatchNow(ctx); err != nil {
				log.Printf("sending batch on ticker: %v", err)
			}
		case n, ok := <-notificationCh:
			if !ok {
				// The notification pump has stopped, so nothing more will arrive.
				return
			}
			batch[n.Payload] = struct{}{}
			if err := maybeSendBatch(ctx); err != nil {
				log.Printf("sending batch because batch is full: %v", err)
			}
		}
	}
}