package main import ( "context" "flag" "fmt" "log" "github.com/aws/aws-sdk-go/aws" awscreds "github.com/aws/aws-sdk-go/aws/credentials" awssession "github.com/aws/aws-sdk-go/aws/session" "github.com/go-stomp/stomp/v3" pgx "github.com/jackc/pgx/v4" "github.com/jlaffaye/ftp" "gocloud.dev/blob" "gocloud.dev/blob/s3blob" "hg.lukegb.com/lukegb/depot/go/trains/darwin" "hg.lukegb.com/lukegb/depot/go/trains/darwin/darwindb" "hg.lukegb.com/lukegb/depot/go/trains/darwin/darwingest" "hg.lukegb.com/lukegb/depot/go/trains/darwin/darwingest/darwingeststomp" ) var ( cleanStart = flag.Bool("clean_start", false, "Start by downloading all timetabling and reference data and restart from snapshot") awsAccessKey = requiredString("s3_access_key", "Darwin XML timetable S3 access key") awsAccessSecret = requiredString("s3_access_secret", "Darwin XML timetable S3 access secret") awsBucket = flag.String("s3_bucket", "darwin.xmltimetable", "Darwin XML timetable S3 bucket") awsPrefix = flag.String("s3_prefix", "PPTimetable/", "Darwin XML timetable S3 bucket key prefix") postgresString = flag.String("pg_string", "host=/var/run/postgresql database=trains search_path=darwindb", "String to connect to PG") stompHostport = requiredString("stomp_address", "Darwin STOMP address (host:port)") stompUsername = requiredString("stomp_username", "Darwin STOMP username") stompPassword = requiredString("stomp_password", "Darwin STOMP password") stompClientID = requiredString("stomp_client_id", "Darwin STOMP client ID") stompTopic = flag.String("stomp_topic", "/topic/darwin.pushport-v16", "Darwin STOMP topic") ftpHostport = requiredString("ftp_address", "Darwin FTP address (host:port)") ftpUsername = requiredString("ftp_username", "Darwin FTP username") ftpPassword = requiredString("ftp_password", "Darwin FTP password") requiredFlags []string ) func requiredString(name, descr string) *string { requireFlag(name) return flag.String(name, "", descr) } func requireFlag(names ...string) { requiredFlags = append(requiredFlags, names...) } func checkRequiredFlags() { for _, fname := range requiredFlags { f := flag.CommandLine.Lookup(fname) if f == nil { log.Fatalf("required flag %q doesn't actually exist", fname) } v := f.Value.(flag.Getter).Get() if v == nil { log.Fatalf("required flag %q not set", fname) } if v, ok := v.(string); ok && v == "" { log.Fatalf("required flag %q not set or empty", fname) } } } func main() { flag.Parse() checkRequiredFlags() ctx, cancel := context.WithCancel(context.Background()) defer cancel() awsSess, err := awssession.NewSession(&aws.Config{ Region: aws.String("eu-west-1"), Credentials: awscreds.NewStaticCredentials(*awsAccessKey, *awsAccessSecret, ""), }) if err != nil { log.Fatalf("awssession.NewSession: %w", err) } s3Bucket, err := s3blob.OpenBucket(ctx, awsSess, *awsBucket, nil) if err != nil { log.Fatalf("s3blob.OpenBucket: %w", err) } defer s3Bucket.Close() bucket := blob.PrefixedBucket(s3Bucket, *awsPrefix) defer bucket.Close() pc, err := pgx.Connect(ctx, *postgresString) if err != nil { log.Fatalf("pgx.Connect: %v", err) } defer pc.Close(context.Background()) catchupTx, err := pc.Begin(ctx) if err != nil { log.Fatalf("beginning catchup transaction: %v", err) } catchupAffected := darwindb.NewAffected() gimmeTransaction := func(ctx context.Context, cb func(ctx context.Context, tx pgx.Tx) error) error { if catchupTx != nil { return cb(ctx, catchupTx) } tx, err := pc.Begin(ctx) if err != nil { return fmt.Errorf("starting transaction: %w", err) } defer tx.Rollback(ctx) if err := cb(ctx, tx); err != nil { return err } return tx.Commit(ctx) } notify := func(ctx context.Context, tx pgx.Tx, s string) error { if _, err := tx.Exec(ctx, fmt.Sprintf("NOTIFY %q", s)); err != nil { return fmt.Errorf("notifying channel for %q: %w", s, err) } if _, err := tx.Exec(ctx, fmt.Sprintf("NOTIFY \"firehose\", '%s'", s)); err != nil { return fmt.Errorf("notifying firehose for %q: %w", s, err) } return nil } sendAffectedNotifications := func(ctx context.Context, tx pgx.Tx, affected *darwindb.Affected) error { for tsid := range affected.TSIDs { if err := notify(ctx, tx, fmt.Sprintf("tsid-%d", tsid)); err != nil { return fmt.Errorf("notifying channel for tsid %d: %w", tsid, err) } } for tiploc := range affected.TIPLOCs { if err := notify(ctx, tx, fmt.Sprintf("tiploc-%s", tiploc)); err != nil { return fmt.Errorf("notifying channel for tiploc %s: %w", tiploc, err) } } return nil } processMessage := func(ctx context.Context, pp *darwin.PushPort) error { return gimmeTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { affected := catchupAffected sendAffected := false if affected == nil { affected = darwindb.NewAffected() sendAffected = true } if err := darwindb.Process(ctx, tx, pp, affected); err != nil { return fmt.Errorf("darwindb.Process: %w", err) } if sendAffected { return sendAffectedNotifications(ctx, tx, affected) } return nil }) } cfg := darwingest.Config{ Verbose: true, DialStomp: func(ctx context.Context) (*stomp.Conn, error) { return darwingeststomp.Dial(ctx, darwingeststomp.Config{ Hostname: *stompHostport, Username: *stompUsername, Password: *stompPassword, ClientID: *stompClientID, }) }, StompSubscribe: func(ctx context.Context, c *stomp.Conn) (*stomp.Subscription, error) { return c.Subscribe(*stompTopic, stomp.AckClientIndividual) }, DialFTP: func(ctx context.Context) (*ftp.ServerConn, error) { sc, err := ftp.Dial(*ftpHostport, ftp.DialWithContext(ctx)) if err != nil { return nil, fmt.Errorf("ftp.Dial: %w", err) } if err := sc.Login(*ftpUsername, *ftpPassword); err != nil { sc.Quit() return nil, fmt.Errorf("ftp.ServerConn.Login: %w", err) } return sc, nil }, Timetable: func(ppt *darwin.PushPortTimetable) error { return gimmeTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { if err := darwindb.ProcessTimetable(ctx, tx, ppt, catchupAffected); err != nil { return fmt.Errorf("processing timetable data: %w", err) } return nil }) }, ReferenceData: func(pprd *darwin.PushPortReferenceData) error { return gimmeTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { if err := darwindb.ProcessReferenceData(ctx, tx, pprd, catchupAffected); err != nil { return fmt.Errorf("processing reference data: %w", err) } return nil }) }, MessageCatchUp: func(pp *darwin.PushPort) error { return processMessage(ctx, pp) }, CatchUpComplete: func(ctx context.Context) error { log.Printf("catchup affected: %v", catchupAffected.Summary()) if err := sendAffectedNotifications(ctx, catchupTx, catchupAffected); err != nil { return fmt.Errorf("sending affected notifications: %v", err) } if err := catchupTx.Commit(ctx); err != nil { return fmt.Errorf("committing catchup transaction: %v", err) } catchupAffected = nil catchupTx = nil return nil }, Message: func(pp *darwin.PushPort) error { return processMessage(ctx, pp) }, } if *cleanStart { if err := darwingest.CatchUpAndStart(ctx, bucket, cfg); err != nil { log.Printf("CatchUpAndStart: %v", err) } } else { if err := darwingest.Start(ctx, bucket, cfg); err != nil { log.Printf("Start: %v", err) } } log.Println("terminating...") }