package main import ( "context" "encoding/json" "errors" "fmt" "log" "net/http" "net/url" "regexp" "strconv" "time" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "hg.lukegb.com/lukegb/depot/go/trains/webapi" ) type querier interface { Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row } func summarizeService(ctx context.Context, pc querier, id int) (*webapi.ServiceData, error) { row := pc.QueryRow(ctx, ` SELECT ts.id, ts.rid, ts.uid, ts.rsid, ts.headcode, ts.scheduled_start_date::varchar, ts.train_operator, rop.name, rop.url, COALESCE(ts.delay_reason_code::int, 0), COALESCE(rlr.text, ''), COALESCE(ts.delay_reason_tiploc, ''), COALESCE(rlrloc.name, ''), rlrloc.crs, COALESCE(rlrloc.toc, ''), rlrop.name, rlrop.url, COALESCE(ts.delay_reason_tiploc_near, false), COALESCE(ts.cancellation_reason_code::int, 0), COALESCE(rlr.text, ''), COALESCE(ts.cancellation_reason_tiploc, ''), COALESCE(rlrloc.name, ''), rlrloc.crs, COALESCE(rlrloc.toc, ''), rlrop.name, rlrop.url, COALESCE(ts.cancellation_reason_tiploc_near, false), ts.active, ts.deleted, ts.cancelled FROM train_services ts LEFT JOIN ref_tocs rop ON rop.toc=ts.train_operator LEFT JOIN ref_late_running_reasons rlr ON rlr.code=ts.delay_reason_code::int LEFT JOIN ref_locations rlrloc ON rlrloc.tiploc=ts.delay_reason_tiploc LEFT JOIN ref_tocs rlrop ON rlrop.toc=rlrloc.toc LEFT JOIN ref_cancel_reasons rc ON rc.code=ts.cancellation_reason_code::int LEFT JOIN ref_locations rcloc ON rcloc.tiploc=ts.cancellation_reason_tiploc LEFT JOIN ref_tocs rcop ON rcop.toc=rcloc.toc WHERE ts.id=$1 `, id) sd := webapi.ServiceData{ DelayReason: &webapi.DisruptionReason{Location: &webapi.Location{TOC: &webapi.TrainOperator{}}}, CancelReason: &webapi.DisruptionReason{Location: &webapi.Location{TOC: &webapi.TrainOperator{}}}, } if err := row.Scan( &sd.ID, &sd.RID, &sd.UID, &sd.RSID, &sd.Headcode, &sd.StartDate, &sd.TrainOperator.Code, &sd.TrainOperator.Name, &sd.TrainOperator.URL, &sd.DelayReason.Code, &sd.DelayReason.Text, &sd.DelayReason.Location.TIPLOC, &sd.DelayReason.Location.Name, &sd.DelayReason.Location.CRS, &sd.DelayReason.Location.TOC.Code, &sd.DelayReason.Location.TOC.Name, &sd.DelayReason.Location.TOC.URL, &sd.DelayReason.NearLocation, &sd.CancelReason.Code, &sd.CancelReason.Text, &sd.CancelReason.Location.TIPLOC, &sd.CancelReason.Location.Name, &sd.CancelReason.Location.CRS, &sd.CancelReason.Location.TOC.Code, &sd.CancelReason.Location.TOC.Name, &sd.CancelReason.Location.TOC.URL, &sd.CancelReason.NearLocation, &sd.Active, &sd.Deleted, &sd.Cancelled, ); err != nil { return nil, fmt.Errorf("reading from train_services for %d: %w", id, err) } sd.Canonicalize() // Now for the locations. rows, err := pc.Query(ctx, ` SELECT tl.id, tl.tiploc, COALESCE(rloc.name, ''), COALESCE(rloc.crs, ''), COALESCE(rloc.toc, ''), rloctoc.name, rloctoc.url, tl.calling_point::varchar, tl.train_length, tl.service_suppressed, tl.schedule_cancelled, COALESCE(tl.schedule_platform, ''), COALESCE(tl.platform, ''), COALESCE(tl.platform_confirmed, false), COALESCE(tl.platform_suppressed, false), tl.schedule_false_destination_tiploc, COALESCE(rfdloc.name, ''), COALESCE(rfdloc.crs, ''), COALESCE(rfdloc.toc, ''), rfdloctoc.name, rfdloctoc.url, tl.schedule_public_arrival, tl.schedule_working_arrival, tl.estimated_arrival, tl.working_estimated_arrival, tl.actual_arrival, tl.schedule_public_departure, tl.schedule_working_departure, tl.estimated_departure, tl.working_estimated_departure, tl.actual_departure, /* no public_pass */ tl.schedule_working_pass, tl.estimated_pass, tl.working_estimated_pass, tl.actual_pass FROM train_locations tl LEFT JOIN ref_locations rloc ON rloc.tiploc=tl.tiploc LEFT JOIN ref_tocs rloctoc ON rloctoc.toc=rloc.toc LEFT JOIN ref_locations rfdloc ON rfdloc.tiploc=tl.schedule_false_destination_tiploc LEFT JOIN ref_tocs rfdloctoc ON rfdloctoc.toc=rfdloc.toc WHERE tl.tsid=$1 AND tl.active_in_schedule ORDER BY COALESCE(tl.schedule_working_arrival, tl.schedule_working_pass, tl.schedule_working_departure) `, id) if err != nil { return nil, fmt.Errorf("querying locations for %d: %w", id, err) } defer rows.Close() for rows.Next() { loc := webapi.ServiceLocation{ Location: &webapi.Location{TOC: &webapi.TrainOperator{}}, Platform: &webapi.PlatformData{}, FalseDestination: &webapi.Location{TOC: &webapi.TrainOperator{}}, ArrivalTiming: &webapi.TimingData{}, DepartureTiming: &webapi.TimingData{}, PassTiming: &webapi.TimingData{}, } if err := rows.Scan( &loc.ID, &loc.Location.TIPLOC, &loc.Location.Name, &loc.Location.CRS, &loc.Location.TOC.Code, &loc.Location.TOC.Name, &loc.Location.TOC.URL, &loc.CallingPointType, &loc.Length, &loc.Suppressed, &loc.Cancelled, &loc.Platform.Scheduled, &loc.Platform.Live, &loc.Platform.Confirmed, &loc.Platform.Suppressed, &loc.FalseDestination.TIPLOC, &loc.FalseDestination.Name, &loc.FalseDestination.CRS, &loc.FalseDestination.TOC.Code, &loc.FalseDestination.TOC.Name, &loc.FalseDestination.TOC.URL, &loc.ArrivalTiming.PublicScheduled, &loc.ArrivalTiming.WorkingScheduled, &loc.ArrivalTiming.PublicEstimated, &loc.ArrivalTiming.WorkingEstimated, &loc.ArrivalTiming.Actual, &loc.DepartureTiming.PublicScheduled, &loc.DepartureTiming.WorkingScheduled, &loc.DepartureTiming.PublicEstimated, &loc.DepartureTiming.WorkingEstimated, &loc.DepartureTiming.Actual, /*&loc.PassTiming.PublicScheduled,*/ &loc.PassTiming.WorkingScheduled, &loc.PassTiming.PublicEstimated, &loc.PassTiming.WorkingEstimated, &loc.PassTiming.Actual, ); err != nil { return nil, fmt.Errorf("scanning locations for %d: %w", id, err) } loc.Canonicalize() sd.Locations = append(sd.Locations, loc) } if rows.Err() != nil { return nil, fmt.Errorf("iterating over locations for %d: %w", id, err) } return &sd, nil } type watchingResponseWriter struct { http.ResponseWriter http.Flusher WrittenTo bool WroteStatus int WroteContentBytes int } func (rw *watchingResponseWriter) Flush() { if rw.Flusher != nil { rw.Flusher.Flush() } } func (rw *watchingResponseWriter) Header() http.Header { return rw.ResponseWriter.Header() } func (rw *watchingResponseWriter) Write(bs []byte) (int, error) { if !rw.WrittenTo { rw.WroteStatus = http.StatusOK } rw.WrittenTo = true rw.WroteContentBytes += len(bs) return rw.ResponseWriter.Write(bs) } func (rw *watchingResponseWriter) WriteHeader(statusCode int) { if rw.WrittenTo { return } rw.WrittenTo = true rw.WroteStatus = statusCode rw.ResponseWriter.WriteHeader(statusCode) } type httpError struct { err error httpStatusCode int publicError string } func (he httpError) Error() string { if he.err == nil { return fmt.Sprintf("HTTP %d: %v", he.httpStatusCode, he.publicError) } return he.err.Error() } type server struct { dbPool *pgxpool.Pool } var jsonPathRegexp = regexp.MustCompile(`/([1-9][0-9]*)$`) func (s *server) handleJSON(ctx context.Context, rw http.ResponseWriter, r *http.Request) error { // / ms := jsonPathRegexp.FindStringSubmatch(r.URL.Path) if len(ms) != 2 { return httpError{ httpStatusCode: http.StatusNotFound, publicError: "not found", } } id, err := strconv.Atoi(ms[1]) if err != nil { return httpError{ err: err, httpStatusCode: http.StatusBadRequest, publicError: "bad id", } } conn, err := s.dbPool.Acquire(ctx) if err != nil { return fmt.Errorf("acquiring database connection: %w", err) } defer conn.Release() svc, err := summarizeService(ctx, conn, id) if errors.Is(err, pgx.ErrNoRows) { return httpError{ err: err, httpStatusCode: http.StatusNotFound, publicError: "no train with that ID", } } else if err != nil { return err } rw.Header().Set("Content-Type", "application/json; charset=utf-8") if err := json.NewEncoder(rw).Encode(svc); err != nil { return fmt.Errorf("encoding JSON: %w", err) } return nil } var channelTIDRegex = regexp.MustCompile(`tsid-([0-9]+)`) func (s *server) handleEventStream(ctx context.Context, rw http.ResponseWriter, r *http.Request) error { ctx, cancel := context.WithCancel(ctx) defer cancel() if r.URL.Path != "/" { return httpError{ httpStatusCode: http.StatusNotFound, publicError: "not found", } } qs, err := url.ParseQuery(r.URL.RawQuery) if err != nil { return httpError{ err: err, httpStatusCode: http.StatusBadRequest, publicError: "can't parse query string", } } var ids []int for _, idStr := range qs["id"] { id, err := strconv.Atoi(idStr) if err != nil { return httpError{ err: err, httpStatusCode: http.StatusBadRequest, publicError: "id parameter bad", } } ids = append(ids, id) } if len(ids) == 0 { return httpError{ httpStatusCode: http.StatusBadRequest, publicError: "need at least one ?id= to watch", } } listenConn, err := s.dbPool.Acquire(ctx) if err != nil { return fmt.Errorf("acquiring database connection: %w", err) } defer listenConn.Release() for _, id := range ids { _, err := listenConn.Exec(ctx, fmt.Sprintf(`LISTEN "tsid-%d"`, id)) if err != nil { return fmt.Errorf("watching for %d: %w", id, err) } } rw.Header().Set("Content-Type", "text/event-stream") rw.Header().Set("Cache-Control", "no-cache") je := json.NewEncoder(rw) encodeSvc := func(conn querier, id int) error { svc, err := summarizeService(ctx, conn, id) if errors.Is(err, pgx.ErrNoRows) { if _, err := fmt.Fprintf(rw, "event: notyet\ndata: %d\n\n", id); err != nil { return err } } else if err != nil { return fmt.Errorf("summarizeService(%d): %w", id, err) } if _, err := fmt.Fprint(rw, "data: "); err != nil { return fmt.Errorf("writing data: prefix for service %d: %w", id, err) } if err := je.Encode(svc); err != nil { return fmt.Errorf("je.Encode service %d: %w", id, err) } if _, err := fmt.Fprint(rw, "\n\n"); err != nil { return fmt.Errorf("printing separator after service %d: %w", id, err) } if f, ok := rw.(http.Flusher); ok { f.Flush() } return nil } for _, id := range ids { if err := encodeSvc(listenConn, id); err != nil { return fmt.Errorf("initial data for service %d: %w", err) } } type updatedMsg struct { TID int Err error } updatedCh := make(chan updatedMsg) go func() { defer close(updatedCh) for { n, err := listenConn.Conn().WaitForNotification(ctx) if err != nil { select { case <-ctx.Done(): return case updatedCh <- updatedMsg{Err: err}: } } ms := channelTIDRegex.FindStringSubmatch(n.Channel) if len(ms) != 2 { log.Printf("unrecognised notification channel %q", n.Channel) continue } id, err := strconv.Atoi(ms[1]) if err != nil { log.Printf("failed to parse TID from notification channel %q: %v", n.Channel, err) continue } select { case <-ctx.Done(): return case updatedCh <- updatedMsg{TID: id}: } } }() pingCh := time.NewTicker(1 * time.Minute) defer pingCh.Stop() for { select { case <-ctx.Done(): return ctx.Err() case <-pingCh.C: if _, err := fmt.Fprint(rw, "event: ping\n\n"); err != nil { return err } if f, ok := rw.(http.Flusher); ok { f.Flush() } case umsg, ok := <-updatedCh: if !ok { log.Printf("updatedCh closed, shutting down conn") return nil } if umsg.Err != nil { return umsg.Err } sconn, err := s.dbPool.Acquire(ctx) if err != nil { return fmt.Errorf("acquiring database connection: %w", err) } err = encodeSvc(sconn, umsg.TID) sconn.Release() if err != nil { return err } } } } func (s *server) handleHTTP(ctx context.Context, rw http.ResponseWriter, r *http.Request) error { accept := r.Header.Get("Accept") switch accept { case "application/json": return s.handleJSON(ctx, rw, r) case "text/event-stream": return s.handleEventStream(ctx, rw, r) default: return httpError{ httpStatusCode: http.StatusNotAcceptable, publicError: "Accept should be text/event-stream or application/json", } } } func (s *server) ServeHTTP(rw http.ResponseWriter, r *http.Request) { wrw := &watchingResponseWriter{ResponseWriter: rw, Flusher: rw.(http.Flusher)} if err := s.handleHTTP(r.Context(), wrw, r); err != nil { statusCode := http.StatusInternalServerError errText := fmt.Sprintf("internal server error: %v", err.Error()) var he httpError if errors.As(err, &he) { if he.httpStatusCode != 0 { statusCode = he.httpStatusCode } if he.publicError != "" { errText = he.publicError } } log.Printf("[%v] %v %v -- %v (%d %v)", r.RemoteAddr, r.Method, r.RequestURI, err, statusCode, errText) if !wrw.WrittenTo { for k := range wrw.Header() { wrw.Header().Del(k) } wrw.Header().Set("Content-Type", "text/plain") wrw.WriteHeader(statusCode) fmt.Fprintf(wrw, "%s\n", errText) } } else { log.Printf("[%v] %v %v -- (%d) %d bytes", r.RemoteAddr, r.Method, r.RequestURI, wrw.WroteStatus, wrw.WroteContentBytes) } } func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() cfg, err := pgxpool.ParseConfig("host=/var/run/postgresql database=trains search_path=darwindb") if err != nil { log.Fatalf("pgxpool.ParseConfig: %v", err) } cfg.AfterRelease = func(pc *pgx.Conn) bool { // true to return to pool, false to destroy ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() _, err := pc.Exec(ctx, "UNLISTEN *") if err != nil { log.Printf("pg connection cleanup: UNLISTEN failed: %v", err) return false } return true } pcpool, err := pgxpool.ConnectConfig(ctx, cfg) if err != nil { log.Fatalf("pgxpool.ConnectConfig: %v", err) } defer pcpool.Close() s := &server{ dbPool: pcpool, } listen := ":13974" log.Printf("listening on %s", listen) log.Fatal(http.ListenAndServe(listen, s)) }