235 lines
7.4 KiB
Go
235 lines
7.4 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
awscreds "github.com/aws/aws-sdk-go/aws/credentials"
|
|
awssession "github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/go-stomp/stomp/v3"
|
|
pgx "github.com/jackc/pgx/v4"
|
|
"github.com/jlaffaye/ftp"
|
|
"gocloud.dev/blob"
|
|
"gocloud.dev/blob/s3blob"
|
|
"hg.lukegb.com/lukegb/depot/go/trains/darwin"
|
|
"hg.lukegb.com/lukegb/depot/go/trains/darwin/darwindb"
|
|
"hg.lukegb.com/lukegb/depot/go/trains/darwin/darwingest"
|
|
"hg.lukegb.com/lukegb/depot/go/trains/darwin/darwingest/darwingeststomp"
|
|
)
|
|
|
|
var (
|
|
cleanStart = flag.Bool("clean_start", false, "Start by downloading all timetabling and reference data and restart from snapshot")
|
|
|
|
awsAccessKey = requiredString("s3_access_key", "Darwin XML timetable S3 access key")
|
|
awsAccessSecret = requiredString("s3_access_secret", "Darwin XML timetable S3 access secret")
|
|
awsBucket = flag.String("s3_bucket", "darwin.xmltimetable", "Darwin XML timetable S3 bucket")
|
|
awsPrefix = flag.String("s3_prefix", "PPTimetable/", "Darwin XML timetable S3 bucket key prefix")
|
|
|
|
postgresString = flag.String("pg_string", "host=/var/run/postgresql database=trains search_path=darwindb", "String to connect to PG")
|
|
|
|
stompHostport = requiredString("stomp_address", "Darwin STOMP address (host:port)")
|
|
stompUsername = requiredString("stomp_username", "Darwin STOMP username")
|
|
stompPassword = requiredString("stomp_password", "Darwin STOMP password")
|
|
stompClientID = requiredString("stomp_client_id", "Darwin STOMP client ID")
|
|
stompTopic = flag.String("stomp_topic", "/topic/darwin.pushport-v16", "Darwin STOMP topic")
|
|
|
|
ftpHostport = requiredString("ftp_address", "Darwin FTP address (host:port)")
|
|
ftpUsername = requiredString("ftp_username", "Darwin FTP username")
|
|
ftpPassword = requiredString("ftp_password", "Darwin FTP password")
|
|
|
|
requiredFlags []string
|
|
)
|
|
|
|
func requiredString(name, descr string) *string {
|
|
requireFlag(name)
|
|
return flag.String(name, "", descr)
|
|
}
|
|
|
|
func requireFlag(names ...string) {
|
|
requiredFlags = append(requiredFlags, names...)
|
|
}
|
|
|
|
func checkRequiredFlags() {
|
|
for _, fname := range requiredFlags {
|
|
f := flag.CommandLine.Lookup(fname)
|
|
if f == nil {
|
|
log.Fatalf("required flag %q doesn't actually exist", fname)
|
|
}
|
|
v := f.Value.(flag.Getter).Get()
|
|
if v == nil {
|
|
log.Fatalf("required flag %q not set", fname)
|
|
}
|
|
if v, ok := v.(string); ok && v == "" {
|
|
log.Fatalf("required flag %q not set or empty", fname)
|
|
}
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
checkRequiredFlags()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
awsSess, err := awssession.NewSession(&aws.Config{
|
|
Region: aws.String("eu-west-1"),
|
|
Credentials: awscreds.NewStaticCredentials(*awsAccessKey, *awsAccessSecret, ""),
|
|
})
|
|
if err != nil {
|
|
log.Fatalf("awssession.NewSession: %w", err)
|
|
}
|
|
s3Bucket, err := s3blob.OpenBucket(ctx, awsSess, *awsBucket, nil)
|
|
if err != nil {
|
|
log.Fatalf("s3blob.OpenBucket: %w", err)
|
|
}
|
|
defer s3Bucket.Close()
|
|
|
|
bucket := blob.PrefixedBucket(s3Bucket, *awsPrefix)
|
|
defer bucket.Close()
|
|
|
|
pc, err := pgx.Connect(ctx, *postgresString)
|
|
if err != nil {
|
|
log.Fatalf("pgx.Connect: %v", err)
|
|
}
|
|
defer pc.Close(context.Background())
|
|
|
|
catchupTx, err := pc.Begin(ctx)
|
|
if err != nil {
|
|
log.Fatalf("beginning catchup transaction: %v", err)
|
|
}
|
|
catchupAffected := darwindb.NewAffected()
|
|
|
|
gimmeTransaction := func(ctx context.Context, cb func(ctx context.Context, tx pgx.Tx) error) error {
|
|
if catchupTx != nil {
|
|
return cb(ctx, catchupTx)
|
|
}
|
|
|
|
tx, err := pc.Begin(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("starting transaction: %w", err)
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
if err := cb(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit(ctx)
|
|
}
|
|
notify := func(ctx context.Context, tx pgx.Tx, s string) error {
|
|
if _, err := tx.Exec(ctx, fmt.Sprintf("NOTIFY %q", s)); err != nil {
|
|
return fmt.Errorf("notifying channel for %q: %w", s, err)
|
|
}
|
|
if _, err := tx.Exec(ctx, fmt.Sprintf("NOTIFY \"firehose\", '%s'", s)); err != nil {
|
|
return fmt.Errorf("notifying firehose for %q: %w", s, err)
|
|
}
|
|
return nil
|
|
}
|
|
sendAffectedNotifications := func(ctx context.Context, tx pgx.Tx, affected *darwindb.Affected) error {
|
|
for tsid := range affected.TSIDs {
|
|
if err := notify(ctx, tx, fmt.Sprintf("tsid-%d", tsid)); err != nil {
|
|
return fmt.Errorf("notifying channel for tsid %d: %w", tsid, err)
|
|
}
|
|
}
|
|
for tiploc := range affected.TIPLOCs {
|
|
if err := notify(ctx, tx, fmt.Sprintf("tiploc-%s", tiploc)); err != nil {
|
|
return fmt.Errorf("notifying channel for tiploc %s: %w", tiploc, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
processMessage := func(ctx context.Context, pp *darwin.PushPort) error {
|
|
return gimmeTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error {
|
|
affected := catchupAffected
|
|
sendAffected := false
|
|
if affected == nil {
|
|
affected = darwindb.NewAffected()
|
|
sendAffected = true
|
|
}
|
|
if err := darwindb.Process(ctx, tx, pp, affected); err != nil {
|
|
return fmt.Errorf("darwindb.Process: %w", err)
|
|
}
|
|
if sendAffected {
|
|
return sendAffectedNotifications(ctx, tx, affected)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
cfg := darwingest.Config{
|
|
Verbose: true,
|
|
DialStomp: func(ctx context.Context) (*stomp.Conn, error) {
|
|
return darwingeststomp.Dial(ctx, darwingeststomp.Config{
|
|
Hostname: *stompHostport,
|
|
Username: *stompUsername,
|
|
Password: *stompPassword,
|
|
ClientID: *stompClientID,
|
|
})
|
|
},
|
|
StompSubscribe: func(ctx context.Context, c *stomp.Conn) (*stomp.Subscription, error) {
|
|
return c.Subscribe(*stompTopic, stomp.AckClientIndividual)
|
|
},
|
|
DialFTP: func(ctx context.Context) (*ftp.ServerConn, error) {
|
|
sc, err := ftp.Dial(*ftpHostport, ftp.DialWithContext(ctx))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ftp.Dial: %w", err)
|
|
}
|
|
|
|
if err := sc.Login(*ftpUsername, *ftpPassword); err != nil {
|
|
sc.Quit()
|
|
return nil, fmt.Errorf("ftp.ServerConn.Login: %w", err)
|
|
}
|
|
return sc, nil
|
|
},
|
|
|
|
Timetable: func(ppt *darwin.PushPortTimetable) error {
|
|
return gimmeTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error {
|
|
if err := darwindb.ProcessTimetable(ctx, tx, ppt, catchupAffected); err != nil {
|
|
return fmt.Errorf("processing timetable data: %w", err)
|
|
}
|
|
return nil
|
|
})
|
|
},
|
|
ReferenceData: func(pprd *darwin.PushPortReferenceData) error {
|
|
return gimmeTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error {
|
|
if err := darwindb.ProcessReferenceData(ctx, tx, pprd, catchupAffected); err != nil {
|
|
return fmt.Errorf("processing reference data: %w", err)
|
|
}
|
|
return nil
|
|
})
|
|
},
|
|
MessageCatchUp: func(pp *darwin.PushPort) error {
|
|
return processMessage(ctx, pp)
|
|
},
|
|
CatchUpComplete: func(ctx context.Context) error {
|
|
log.Printf("catchup affected: %v", catchupAffected.Summary())
|
|
if err := sendAffectedNotifications(ctx, catchupTx, catchupAffected); err != nil {
|
|
return fmt.Errorf("sending affected notifications: %v", err)
|
|
}
|
|
if err := catchupTx.Commit(ctx); err != nil {
|
|
return fmt.Errorf("committing catchup transaction: %v", err)
|
|
}
|
|
catchupAffected = nil
|
|
catchupTx = nil
|
|
return nil
|
|
},
|
|
Message: func(pp *darwin.PushPort) error {
|
|
return processMessage(ctx, pp)
|
|
},
|
|
}
|
|
if *cleanStart {
|
|
if err := darwingest.CatchUpAndStart(ctx, bucket, cfg); err != nil {
|
|
log.Printf("CatchUpAndStart: %v", err)
|
|
}
|
|
} else {
|
|
if err := darwingest.Start(ctx, bucket, cfg); err != nil {
|
|
log.Printf("Start: %v", err)
|
|
}
|
|
}
|
|
|
|
log.Println("terminating...")
|
|
}
|