depot/go/trains/cmd/db2web/db2web.go
Luke Granger-Brown 7af28f3ef5 go/trains: split summarizeService into its own package
It might come in handy when sending push notifications to devices...
2021-11-20 19:12:47 +00:00

450 lines
10 KiB
Go

package main
import (
"context"
"errors"
"flag"
"fmt"
"log"
"net/http"
"net/url"
"regexp"
"strconv"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/encoding/protojson"
"hg.lukegb.com/lukegb/depot/go/trains/webapi/summarize"
)
// allowStress enables the /stress endpoint (reached with Accept: text/plain),
// which summarizes every train service in the database to check that all
// records can be processed. Off by default.
var allowStress = flag.Bool("allow_stress", false, "Allow checking the validity of every record in the database")
// watchingResponseWriter wraps an http.ResponseWriter and records what the
// handler has sent through it. ServeHTTP uses it to decide afterwards whether
// an error response can still be written, and to log the outcome.
type watchingResponseWriter struct {
	http.ResponseWriter
	// Flusher, if non-nil, backs Flush. ServeHTTP sets it from the same
	// object as ResponseWriter.
	http.Flusher

	// WrittenTo is set once a status has been committed, either explicitly
	// via WriteHeader or implicitly by the first Write.
	WrittenTo bool
	// WroteStatus is the committed status code (200 for an implicit header).
	WroteStatus int
	// WroteContentBytes counts the body bytes passed to Write.
	WroteContentBytes int
}

// Flush forwards to the wrapped Flusher and does nothing if there is none.
func (w *watchingResponseWriter) Flush() {
	if f := w.Flusher; f != nil {
		f.Flush()
	}
}

// Header returns the wrapped writer's header map.
func (w *watchingResponseWriter) Header() http.Header {
	return w.ResponseWriter.Header()
}

// Write records the implicit 200 status on first use, counts the bytes and
// forwards them to the wrapped writer.
func (w *watchingResponseWriter) Write(p []byte) (int, error) {
	if !w.WrittenTo {
		w.WrittenTo = true
		w.WroteStatus = http.StatusOK
	}
	w.WroteContentBytes += len(p)
	return w.ResponseWriter.Write(p)
}

// WriteHeader commits statusCode only if no status has been committed yet.
// Later calls are dropped rather than forwarded.
func (w *watchingResponseWriter) WriteHeader(statusCode int) {
	if !w.WrittenTo {
		w.WrittenTo = true
		w.WroteStatus = statusCode
		w.ResponseWriter.WriteHeader(statusCode)
	}
}
// httpError is an error that carries the HTTP status and the client-facing
// message ServeHTTP should respond with. The wrapped err (if any) is only
// used for logging via Error.
type httpError struct {
	err            error  // underlying cause; may be nil
	httpStatusCode int    // status to respond with; 0 means use the default
	publicError    string // message safe to show to the client
}

// Error returns the wrapped error's text when there is one. Otherwise it
// describes the status code and public message.
func (e httpError) Error() string {
	if e.err != nil {
		return e.err.Error()
	}
	return fmt.Sprintf("HTTP %d: %v", e.httpStatusCode, e.publicError)
}
// server is the http.Handler for the trains web API. All request handlers
// draw their database connections from a single shared pool.
type server struct {
	// dbPool supplies connections both for queries and for LISTEN-based
	// update streams.
	dbPool *pgxpool.Pool
}
// stress summarizes and protojson-encodes every train service in the
// database, failing on the first service that cannot be processed. It is
// reachable only when -allow_stress is set (see handleHTTP). Nothing is
// written to rw on success; each processed ID is logged.
//
// IDs are streamed from one query connection to a fixed set of worker
// goroutines, each of which holds its own pooled connection while it runs.
func (s *server) stress(ctx context.Context, rw http.ResponseWriter, r *http.Request) error {
	if r.URL.Path != "/stress" {
		return httpError{
			httpStatusCode: http.StatusNotFound, publicError: "not found",
		}
	}

	// Acquire the ID-enumeration connection *before* starting the workers.
	// Each worker holds a pooled connection for its whole lifetime. If the
	// workers exhausted the pool first, this Acquire could block forever
	// while they wait for IDs that never arrive.
	idConn, err := s.dbPool.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("acquiring idConn: %w", err)
	}
	defer idConn.Release()

	const numWorkers = 32
	ch := make(chan int, numWorkers)
	eg, egCtx := errgroup.WithContext(ctx)
	for n := 0; n < numWorkers; n++ {
		eg.Go(func() error {
			conn, err := s.dbPool.Acquire(egCtx)
			if err != nil {
				return fmt.Errorf("acquiring conn: %w", err)
			}
			defer conn.Release()
			for {
				select {
				case <-egCtx.Done():
					return egCtx.Err()
				case tsid, ok := <-ch:
					if !ok {
						// Producer has finished: no more work.
						return nil
					}
					svc, err := summarize.Service(egCtx, conn, tsid)
					if err != nil {
						return fmt.Errorf("summarize.Service %v: %w", tsid, err)
					}
					if _, err := protojson.Marshal(svc); err != nil {
						return fmt.Errorf("protojson.Marshal %v: %w", tsid, err)
					}
					log.Println(tsid)
				}
			}
		})
	}

	// Feed IDs to the workers. ch is always closed (deferred) so the workers
	// terminate even if enumeration fails part-way, and eg.Wait below is
	// always reached so no worker outlives this handler.
	produceErr := func() error {
		defer close(ch)
		rows, err := idConn.Query(ctx, "SELECT id FROM train_services ORDER BY id")
		if err != nil {
			return fmt.Errorf("querying for all IDs: %w", err)
		}
		defer rows.Close()
		for rows.Next() {
			var id int
			if err := rows.Scan(&id); err != nil {
				return fmt.Errorf("scanning ID: %w", err)
			}
			select {
			case ch <- id:
			case <-egCtx.Done():
				// A worker failed; its error is reported by eg.Wait.
				return nil
			}
		}
		if err := rows.Err(); err != nil {
			return fmt.Errorf("iterating IDs: %w", err)
		}
		return nil
	}()

	if err := eg.Wait(); err != nil {
		return err
	}
	return produceErr
}
// jsonPathRegexp matches request paths ending in "/<id>", where <id> is a
// positive decimal integer without leading zeros. The first submatch is the
// ID. The pattern is not anchored at the start, so any prefix is accepted.
var (
	jsonPathRegexp = regexp.MustCompile(`/([1-9][0-9]*)$`)
)
// handleJSON serves the summary of a single train service as JSON. The
// service ID is taken from the trailing "/<id>" of the path (see
// jsonPathRegexp). It returns:
//   - 404 if the path has no ID;
//   - 400 if the ID does not fit in an int;
//   - 404 if summarize.Service reports pgx.ErrNoRows.
func (s *server) handleJSON(ctx context.Context, rw http.ResponseWriter, r *http.Request) error {
	match := jsonPathRegexp.FindStringSubmatch(r.URL.Path)
	if len(match) != 2 {
		return httpError{
			httpStatusCode: http.StatusNotFound,
			publicError:    "not found",
		}
	}
	// The pattern only admits digits, so Atoi can fail only on overflow.
	tsid, err := strconv.Atoi(match[1])
	if err != nil {
		return httpError{
			err:            err,
			httpStatusCode: http.StatusBadRequest,
			publicError:    "bad id",
		}
	}

	dbConn, err := s.dbPool.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("acquiring database connection: %w", err)
	}
	defer dbConn.Release()

	svc, err := summarize.Service(ctx, dbConn, tsid)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return httpError{
			err:            err,
			httpStatusCode: http.StatusNotFound,
			publicError:    "no train with that ID",
		}
	case err != nil:
		return err
	}

	rw.Header().Set("Content-Type", "application/json; charset=utf-8")
	payload, err := protojson.Marshal(svc)
	if err != nil {
		return fmt.Errorf("encoding JSON: %w", err)
	}
	if _, err := rw.Write(payload); err != nil {
		return fmt.Errorf("writing JSON out: %w", err)
	}
	return nil
}
// channelTIDRegex extracts the numeric train service ID from a notification
// channel name of the form "tsid-<id>", as LISTENed to by handleEventStream.
// The first submatch is the ID.
var (
	channelTIDRegex = regexp.MustCompile(`tsid-([0-9]+)`)
)
// handleEventStream serves a text/event-stream feed for one or more train
// services, selected by repeated ?id= query parameters on the root path.
//
// It first emits the current summary of each requested service. It then
// LISTENs on the "tsid-<id>" channels and re-emits a service's summary each
// time a notification arrives on its channel. Events written are:
//   - "data: <protojson>": a service summary;
//   - "event: notyet" / "data: <id>": summarize.Service found no row yet;
//   - "event: ping": sent every minute.
//
// It returns when ctx is cancelled, when writing to rw fails, or when
// waiting for notifications fails.
func (s *server) handleEventStream(ctx context.Context, rw http.ResponseWriter, r *http.Request) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	if r.URL.Path != "/" {
		return httpError{
			httpStatusCode: http.StatusNotFound,
			publicError:    "not found",
		}
	}
	ids, err := parseEventStreamIDs(r)
	if err != nil {
		return err
	}

	listenConn, err := s.dbPool.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("acquiring database connection: %w", err)
	}
	defer listenConn.Release()
	for _, id := range ids {
		// id is an int, so formatting it into the identifier cannot inject SQL.
		if _, err := listenConn.Exec(ctx, fmt.Sprintf(`LISTEN "tsid-%d"`, id)); err != nil {
			return fmt.Errorf("watching for %d: %w", id, err)
		}
	}

	rw.Header().Set("Content-Type", "text/event-stream")
	rw.Header().Set("Cache-Control", "no-cache")

	flush := func() {
		if f, ok := rw.(http.Flusher); ok {
			f.Flush()
		}
	}

	// encodeSvc writes one event for service id, summarized using conn.
	encodeSvc := func(conn summarize.Querier, id int) error {
		svc, err := summarize.Service(ctx, conn, id)
		if errors.Is(err, pgx.ErrNoRows) {
			// Nothing to summarize yet: tell the client and stop here rather
			// than falling through and emitting a data event for a nil svc.
			if _, err := fmt.Fprintf(rw, "event: notyet\ndata: %d\n\n", id); err != nil {
				return fmt.Errorf("writing notyet event for service %d: %w", id, err)
			}
			flush()
			return nil
		}
		if err != nil {
			return fmt.Errorf("summarize.Service(%d): %w", id, err)
		}
		// Marshal before writing anything so a failure can't leave a
		// half-written event on the stream.
		svcBytes, err := protojson.Marshal(svc)
		if err != nil {
			return fmt.Errorf("protojson.Marshal service %d: %w", id, err)
		}
		if _, err := fmt.Fprint(rw, "data: "); err != nil {
			return fmt.Errorf("writing data: prefix for service %d: %w", id, err)
		}
		if _, err := rw.Write(svcBytes); err != nil {
			return fmt.Errorf("writing protojson for service %d: %w", id, err)
		}
		if _, err := fmt.Fprint(rw, "\n\n"); err != nil {
			return fmt.Errorf("printing separator after service %d: %w", id, err)
		}
		flush()
		return nil
	}

	for _, id := range ids {
		if err := encodeSvc(listenConn, id); err != nil {
			return fmt.Errorf("initial data for service %d: %w", id, err)
		}
	}

	type updatedMsg struct {
		TID int
		Err error
	}
	updatedCh := make(chan updatedMsg)
	notifyDone := make(chan struct{})
	go func() {
		defer close(notifyDone)
		defer close(updatedCh)
		for {
			n, err := listenConn.Conn().WaitForNotification(ctx)
			if err != nil {
				select {
				case <-ctx.Done():
				case updatedCh <- updatedMsg{Err: err}:
				}
				// n is not valid when err is non-nil; stop rather than
				// dereferencing it.
				return
			}
			ms := channelTIDRegex.FindStringSubmatch(n.Channel)
			if len(ms) != 2 {
				log.Printf("unrecognised notification channel %q", n.Channel)
				continue
			}
			id, err := strconv.Atoi(ms[1])
			if err != nil {
				log.Printf("failed to parse TID from notification channel %q: %v", n.Channel, err)
				continue
			}
			select {
			case <-ctx.Done():
				return
			case updatedCh <- updatedMsg{TID: id}:
			}
		}
	}()
	// Deferred after listenConn.Release, so this runs first. Cancel the
	// context WaitForNotification is using and wait for the goroutine to exit,
	// so that Release never runs while the goroutine is still using listenConn.
	defer func() {
		cancel()
		<-notifyDone
	}()

	pingTicker := time.NewTicker(1 * time.Minute)
	defer pingTicker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-pingTicker.C:
			if _, err := fmt.Fprint(rw, "event: ping\n\n"); err != nil {
				return fmt.Errorf("writing ping: %w", err)
			}
			flush()
		case umsg, ok := <-updatedCh:
			if !ok {
				log.Printf("updatedCh closed, shutting down conn")
				return nil
			}
			if umsg.Err != nil {
				return fmt.Errorf("waiting for notification: %w", umsg.Err)
			}
			// listenConn is busy waiting for notifications, so summarize on
			// a separate pooled connection.
			sconn, err := s.dbPool.Acquire(ctx)
			if err != nil {
				return fmt.Errorf("acquiring database connection: %w", err)
			}
			err = encodeSvc(sconn, umsg.TID)
			sconn.Release()
			if err != nil {
				return err
			}
		}
	}
}

// parseEventStreamIDs extracts the train service IDs to watch from the
// repeated ?id= query parameters of r. Each ID must be a positive decimal
// integer (as handleJSON requires), and at least one must be given.
// Violations are returned as HTTP 400 httpErrors.
func parseEventStreamIDs(r *http.Request) ([]int, error) {
	qs, err := url.ParseQuery(r.URL.RawQuery)
	if err != nil {
		return nil, httpError{
			err:            err,
			httpStatusCode: http.StatusBadRequest,
			publicError:    "can't parse query string",
		}
	}
	idStrs := qs["id"]
	if len(idStrs) == 0 {
		return nil, httpError{
			httpStatusCode: http.StatusBadRequest,
			publicError:    "need at least one ?id= to watch",
		}
	}
	ids := make([]int, 0, len(idStrs))
	for _, idStr := range idStrs {
		id, err := strconv.Atoi(idStr)
		if err != nil {
			return nil, httpError{
				err:            err,
				httpStatusCode: http.StatusBadRequest,
				publicError:    "id parameter bad",
			}
		}
		if id <= 0 {
			return nil, httpError{
				err:            fmt.Errorf("non-positive id %d", id),
				httpStatusCode: http.StatusBadRequest,
				publicError:    "id parameter bad",
			}
		}
		ids = append(ids, id)
	}
	return ids, nil
}
// handleHTTP picks a handler by exact match on the request's Accept header:
//   - application/json: handleJSON;
//   - text/event-stream: handleEventStream;
//   - text/plain: stress, only when -allow_stress is set.
//
// Anything else gets 406 Not Acceptable.
func (s *server) handleHTTP(ctx context.Context, rw http.ResponseWriter, r *http.Request) error {
	switch r.Header.Get("Accept") {
	case "application/json":
		return s.handleJSON(ctx, rw, r)
	case "text/event-stream":
		return s.handleEventStream(ctx, rw, r)
	case "text/plain":
		if *allowStress {
			return s.stress(ctx, rw, r)
		}
	}
	// Unrecognised Accept, or text/plain with stress mode disabled.
	return httpError{
		httpStatusCode: http.StatusNotAcceptable,
		publicError:    "Accept should be text/event-stream or application/json",
	}
}
// ServeHTTP implements http.Handler. It dispatches to handleHTTP and logs
// every request with its outcome.
//
// If the handler returns an error before anything has been written, the
// prepared headers are discarded and a plain-text error is sent. The status
// and message come from an httpError in the chain when present; otherwise
// 500 with the error text. If the response has already started, the error is
// only logged.
func (s *server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	// Checked assertion: a ResponseWriter that doesn't implement http.Flusher
	// would otherwise cause a panic here on every request.
	// watchingResponseWriter.Flush is a no-op for a nil Flusher.
	flusher, _ := rw.(http.Flusher)
	wrw := &watchingResponseWriter{ResponseWriter: rw, Flusher: flusher}

	err := s.handleHTTP(r.Context(), wrw, r)
	if err == nil {
		log.Printf("[%v] %v %v -- (%d) %d bytes", r.RemoteAddr, r.Method, r.RequestURI, wrw.WroteStatus, wrw.WroteContentBytes)
		return
	}

	statusCode := http.StatusInternalServerError
	errText := fmt.Sprintf("internal server error: %v", err.Error())
	var he httpError
	if errors.As(err, &he) {
		if he.httpStatusCode != 0 {
			statusCode = he.httpStatusCode
		}
		if he.publicError != "" {
			errText = he.publicError
		}
	}
	log.Printf("[%v] %v %v -- %v (%d %v)", r.RemoteAddr, r.Method, r.RequestURI, err, statusCode, errText)

	if wrw.WrittenTo {
		// The status has already been committed; nothing useful can be added.
		return
	}
	// Drop headers the handler set for its intended (successful) response.
	for k := range wrw.Header() {
		wrw.Header().Del(k)
	}
	wrw.Header().Set("Content-Type", "text/plain")
	wrw.WriteHeader(statusCode)
	// Best effort: a failed write here has nowhere better to be reported.
	_, _ = fmt.Fprintf(wrw, "%s\n", errText)
}
// main connects to the local "trains" Postgres database over its Unix
// socket and serves the web API on port 13974.
func main() {
	flag.Parse()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cfg, err := pgxpool.ParseConfig("host=/var/run/postgresql database=trains search_path=darwindb")
	if err != nil {
		log.Fatalf("pgxpool.ParseConfig: %v", err)
	}
	// Drop all LISTEN registrations when a connection goes back to the pool,
	// so its next user doesn't inherit them. Returning false destroys the
	// connection instead of reusing it.
	cfg.AfterRelease = func(conn *pgx.Conn) bool {
		cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), time.Second)
		defer cleanupCancel()
		if _, err := conn.Exec(cleanupCtx, "UNLISTEN *"); err != nil {
			log.Printf("pg connection cleanup: UNLISTEN failed: %v", err)
			return false
		}
		return true
	}

	pool, err := pgxpool.ConnectConfig(ctx, cfg)
	if err != nil {
		log.Fatalf("pgxpool.ConnectConfig: %v", err)
	}
	defer pool.Close()

	const listenAddr = ":13974"
	log.Printf("listening on %s", listenAddr)
	err = http.ListenAndServe(listenAddr, &server{dbPool: pool})
	log.Fatal(err)
}