depot/go/trains/cmd/darwiningestd/darwiningestd.go

235 lines
7.4 KiB
Go

package main
import (
"context"
"flag"
"fmt"
"log"
"github.com/aws/aws-sdk-go/aws"
awscreds "github.com/aws/aws-sdk-go/aws/credentials"
awssession "github.com/aws/aws-sdk-go/aws/session"
"github.com/go-stomp/stomp/v3"
pgx "github.com/jackc/pgx/v4"
"github.com/jlaffaye/ftp"
"gocloud.dev/blob"
"gocloud.dev/blob/s3blob"
"hg.lukegb.com/lukegb/depot/go/trains/darwin"
"hg.lukegb.com/lukegb/depot/go/trains/darwin/darwindb"
"hg.lukegb.com/lukegb/depot/go/trains/darwin/darwingest"
"hg.lukegb.com/lukegb/depot/go/trains/darwin/darwingest/darwingeststomp"
)
// cleanStart selects darwingest.CatchUpAndStart (rather than darwingest.Start)
// in main.
var cleanStart = flag.Bool("clean_start", false, "Start by downloading all timetabling and reference data and restart from snapshot")

// S3 location of the Darwin XML timetable files.
var (
	awsAccessKey    = requiredString("s3_access_key", "Darwin XML timetable S3 access key")
	awsAccessSecret = requiredString("s3_access_secret", "Darwin XML timetable S3 access secret")
	awsBucket       = flag.String("s3_bucket", "darwin.xmltimetable", "Darwin XML timetable S3 bucket")
	awsPrefix       = flag.String("s3_prefix", "PPTimetable/", "Darwin XML timetable S3 bucket key prefix")
)

// postgresString is the connection string handed to pgx.Connect.
var postgresString = flag.String("pg_string", "host=/var/run/postgresql database=trains search_path=darwindb", "String to connect to PG")

// Darwin STOMP feed connection settings.
var (
	stompHostport = requiredString("stomp_address", "Darwin STOMP address (host:port)")
	stompUsername = requiredString("stomp_username", "Darwin STOMP username")
	stompPassword = requiredString("stomp_password", "Darwin STOMP password")
	stompClientID = requiredString("stomp_client_id", "Darwin STOMP client ID")
	stompTopic    = flag.String("stomp_topic", "/topic/darwin.pushport-v16", "Darwin STOMP topic")
)

// Darwin FTP connection settings.
var (
	ftpHostport = requiredString("ftp_address", "Darwin FTP address (host:port)")
	ftpUsername = requiredString("ftp_username", "Darwin FTP username")
	ftpPassword = requiredString("ftp_password", "Darwin FTP password")
)

// requiredFlags accumulates, in registration order, the names of flags declared
// through requiredString/requireFlag; checkRequiredFlags validates each one.
var requiredFlags []string
// requiredString registers a string flag with an empty default, exactly like
// flag.String, and additionally marks it as mandatory so that
// checkRequiredFlags rejects it if it is left unset or empty.
func requiredString(name, descr string) *string {
	requireFlag(name)
	value := flag.String(name, "", descr)
	return value
}
// requireFlag appends the given flag names to requiredFlags, marking them as
// mandatory for checkRequiredFlags. Calling it with no names changes nothing.
func requireFlag(names ...string) {
	if len(names) == 0 {
		return
	}
	requiredFlags = append(requiredFlags, names...)
}
// checkRequiredFlags terminates the program via log.Fatalf if any flag named in
// requiredFlags is not registered on flag.CommandLine, or if its current value
// is nil or an empty string. It inspects parsed values, so it is meant to run
// after flag.Parse.
func checkRequiredFlags() {
	for _, fname := range requiredFlags {
		f := flag.CommandLine.Lookup(fname)
		if f == nil {
			log.Fatalf("required flag %q doesn't actually exist", fname)
		}
		// Standard-library flag values implement flag.Getter, but values
		// registered through flag.Var need not. Fall back to their String
		// form instead of panicking on a failed type assertion.
		getter, ok := f.Value.(flag.Getter)
		if !ok {
			if f.Value.String() == "" {
				log.Fatalf("required flag %q not set or empty", fname)
			}
			continue
		}
		v := getter.Get()
		if v == nil {
			log.Fatalf("required flag %q not set", fname)
		}
		if s, isString := v.(string); isString && s == "" {
			log.Fatalf("required flag %q not set or empty", fname)
		}
	}
}
// main validates flags, opens the S3 bucket holding the Darwin XML timetable
// files and the Postgres connection, and then runs darwingest
// (CatchUpAndStart with -clean_start, Start otherwise) until it returns.
//
// Until CatchUpComplete runs, every database write goes through a single
// transaction (catchupTx), and affected TSIDs/TIPLOCs accumulate in
// catchupAffected. Their notifications are sent just before that transaction
// commits. After that, each call gets its own transaction and sends its own
// notifications.
func main() {
	flag.Parse()
	checkRequiredFlags()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	awsSess, err := awssession.NewSession(&aws.Config{
		Region:      aws.String("eu-west-1"),
		Credentials: awscreds.NewStaticCredentials(*awsAccessKey, *awsAccessSecret, ""),
	})
	if err != nil {
		// The log package does not understand %w (only fmt.Errorf does).
		log.Fatalf("awssession.NewSession: %v", err)
	}
	s3Bucket, err := s3blob.OpenBucket(ctx, awsSess, *awsBucket, nil)
	if err != nil {
		log.Fatalf("s3blob.OpenBucket: %v", err)
	}
	// blob.PrefixedBucket closes s3Bucket and takes over its driver, so only
	// the returned prefixed bucket needs to be closed.
	bucket := blob.PrefixedBucket(s3Bucket, *awsPrefix)
	defer bucket.Close()

	pc, err := pgx.Connect(ctx, *postgresString)
	if err != nil {
		log.Fatalf("pgx.Connect: %v", err)
	}
	defer pc.Close(context.Background())

	// Everything processed before CatchUpComplete shares this transaction.
	catchupTx, err := pc.Begin(ctx)
	if err != nil {
		log.Fatalf("beginning catchup transaction: %v", err)
	}
	catchupAffected := darwindb.NewAffected()

	// gimmeTransaction runs cb inside the catch-up transaction while that is
	// still open. Otherwise it runs cb in a fresh transaction that is committed
	// when cb succeeds and rolled back when it fails.
	gimmeTransaction := func(ctx context.Context, cb func(ctx context.Context, tx pgx.Tx) error) error {
		if catchupTx != nil {
			return cb(ctx, catchupTx)
		}
		tx, err := pc.Begin(ctx)
		if err != nil {
			return fmt.Errorf("starting transaction: %w", err)
		}
		defer tx.Rollback(ctx)
		if err := cb(ctx, tx); err != nil {
			return err
		}
		return tx.Commit(ctx)
	}

	// notify signals the channel named s, and also the shared "firehose"
	// channel with s as its payload. It uses pg_notify with bind parameters
	// rather than Sprintf-built NOTIFY statements, so names derived from feed
	// data (e.g. TIPLOCs) are never spliced into SQL text.
	notify := func(ctx context.Context, tx pgx.Tx, s string) error {
		if _, err := tx.Exec(ctx, "SELECT pg_notify($1, '')", s); err != nil {
			return fmt.Errorf("notifying channel for %q: %w", s, err)
		}
		if _, err := tx.Exec(ctx, "SELECT pg_notify('firehose', $1)", s); err != nil {
			return fmt.Errorf("notifying firehose for %q: %w", s, err)
		}
		return nil
	}

	// sendAffectedNotifications notifies "tsid-<id>" for every affected TSID
	// and "tiploc-<tiploc>" for every affected TIPLOC.
	sendAffectedNotifications := func(ctx context.Context, tx pgx.Tx, affected *darwindb.Affected) error {
		for tsid := range affected.TSIDs {
			if err := notify(ctx, tx, fmt.Sprintf("tsid-%d", tsid)); err != nil {
				return fmt.Errorf("notifying channel for tsid %d: %w", tsid, err)
			}
		}
		for tiploc := range affected.TIPLOCs {
			if err := notify(ctx, tx, fmt.Sprintf("tiploc-%s", tiploc)); err != nil {
				return fmt.Errorf("notifying channel for tiploc %s: %w", tiploc, err)
			}
		}
		return nil
	}

	// processMessage applies one pushport message. During catch-up, affected
	// entities are accumulated for CatchUpComplete. Afterwards they are
	// notified immediately within the message's own transaction.
	processMessage := func(ctx context.Context, pp *darwin.PushPort) error {
		return gimmeTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error {
			affected := catchupAffected
			sendAffected := false
			if affected == nil {
				affected = darwindb.NewAffected()
				sendAffected = true
			}
			if err := darwindb.Process(ctx, tx, pp, affected); err != nil {
				return fmt.Errorf("darwindb.Process: %w", err)
			}
			if sendAffected {
				return sendAffectedNotifications(ctx, tx, affected)
			}
			return nil
		})
	}

	cfg := darwingest.Config{
		Verbose: true,
		DialStomp: func(ctx context.Context) (*stomp.Conn, error) {
			return darwingeststomp.Dial(ctx, darwingeststomp.Config{
				Hostname: *stompHostport,
				Username: *stompUsername,
				Password: *stompPassword,
				ClientID: *stompClientID,
			})
		},
		StompSubscribe: func(ctx context.Context, c *stomp.Conn) (*stomp.Subscription, error) {
			return c.Subscribe(*stompTopic, stomp.AckClientIndividual)
		},
		DialFTP: func(ctx context.Context) (*ftp.ServerConn, error) {
			sc, err := ftp.Dial(*ftpHostport, ftp.DialWithContext(ctx))
			if err != nil {
				return nil, fmt.Errorf("ftp.Dial: %w", err)
			}
			if err := sc.Login(*ftpUsername, *ftpPassword); err != nil {
				sc.Quit()
				return nil, fmt.Errorf("ftp.ServerConn.Login: %w", err)
			}
			return sc, nil
		},
		Timetable: func(ppt *darwin.PushPortTimetable) error {
			return gimmeTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error {
				if err := darwindb.ProcessTimetable(ctx, tx, ppt, catchupAffected); err != nil {
					return fmt.Errorf("processing timetable data: %w", err)
				}
				return nil
			})
		},
		ReferenceData: func(pprd *darwin.PushPortReferenceData) error {
			return gimmeTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error {
				if err := darwindb.ProcessReferenceData(ctx, tx, pprd, catchupAffected); err != nil {
					return fmt.Errorf("processing reference data: %w", err)
				}
				return nil
			})
		},
		MessageCatchUp: func(pp *darwin.PushPort) error {
			return processMessage(ctx, pp)
		},
		CatchUpComplete: func(ctx context.Context) error {
			// Guard against a repeated call, which would otherwise dereference
			// the nil catchupTx below.
			if catchupTx == nil {
				return fmt.Errorf("catchup already completed")
			}
			log.Printf("catchup affected: %v", catchupAffected.Summary())
			if err := sendAffectedNotifications(ctx, catchupTx, catchupAffected); err != nil {
				return fmt.Errorf("sending affected notifications: %w", err)
			}
			if err := catchupTx.Commit(ctx); err != nil {
				return fmt.Errorf("committing catchup transaction: %w", err)
			}
			// From here on, gimmeTransaction opens one transaction per call and
			// processMessage notifies per message.
			catchupAffected = nil
			catchupTx = nil
			return nil
		},
		Message: func(pp *darwin.PushPort) error {
			return processMessage(ctx, pp)
		},
	}

	if *cleanStart {
		if err := darwingest.CatchUpAndStart(ctx, bucket, cfg); err != nil {
			log.Printf("CatchUpAndStart: %v", err)
		}
	} else {
		if err := darwingest.Start(ctx, bucket, cfg); err != nil {
			log.Printf("Start: %v", err)
		}
	}
	log.Println("terminating...")
}