diff --git a/go/trains/cmd/db2web/db2web.go b/go/trains/cmd/db2web/db2web.go index aaf0d61770..65b0d61b69 100644 --- a/go/trains/cmd/db2web/db2web.go +++ b/go/trains/cmd/db2web/db2web.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/protobuf/encoding/protojson" + "hg.lukegb.com/lukegb/depot/go/trains/webapi" "hg.lukegb.com/lukegb/depot/go/trains/webapi/summarize" ) @@ -73,6 +74,76 @@ func (he httpError) Error() string { return he.err.Error() } +type grpcServer struct { + webapi.UnimplementedTrainsServiceServer + + dbPool *pgxpool.Pool +} + +func (s *grpcServer) SummarizeService(ctx context.Context, req *webapi.SummarizeServiceRequest) (*webapi.ServiceData, error) { + conn, err := s.dbPool.Acquire(ctx) + if err != nil { + return nil, fmt.Errorf("dbPool.Acquire: %w", err) + } + defer conn.Release() + + return summarize.Service(ctx, conn, int(req.GetId())) +} + +func (s *grpcServer) ListServices(ctx context.Context, req *webapi.ListServicesRequest) (*webapi.ServiceList, error) { + conn, err := s.dbPool.Acquire(ctx) + if err != nil { + return nil, fmt.Errorf("dbPool.Acquire: %w", err) + } + defer conn.Release() + + sds, err := summarize.ListServices(ctx, conn, req) + if err != nil { + return nil, fmt.Errorf("ListServices: %w", err) + } + + return &webapi.ServiceList{Services: sds}, nil +} + +func (s *grpcServer) ListLocations(ctx context.Context, req *webapi.ListLocationsRequest) (*webapi.LocationList, error) { + conn, err := s.dbPool.Acquire(ctx) + if err != nil { + return nil, fmt.Errorf("dbPool.Acquire: %w", err) + } + defer conn.Release() + + var locs []*webapi.Location + where := "TRUE" + if !req.GetIncludeWithoutCrs() { + where = "rl.crs IS NOT NULL AND rl.crs <> ''" + } + rows, err := conn.Query(ctx, fmt.Sprintf(` +SELECT + rl.tiploc, rl.name, rl.crs, + rltoc.toc, rltoc.name, rltoc.url +FROM + ref_locations rl +INNER JOIN + ref_tocs rltoc ON rltoc.toc = rl.toc +WHERE %s`, where)) + if err != nil { + return nil, fmt.Errorf("querying for locations: %w", err) + } + for rows.Next() { + loc := &webapi.Location{Operator: &webapi.TrainOperator{}} + if err := rows.Scan(&loc.Tiploc, &loc.Name, &loc.Crs, &loc.Operator.Code, &loc.Operator.Name, &loc.Operator.Url); err != nil { + return nil, fmt.Errorf("scanning for locations: %w", err) + } + loc.Canonicalize() + locs = append(locs, loc) + } + if rows.Err() != nil { + return nil, fmt.Errorf("rows error: %w", rows.Err()) + } + + return &webapi.LocationList{Locations: locs}, nil +} + type server struct { dbPool *pgxpool.Pool grpcServer *grpc.Server @@ -368,12 +439,12 @@ func (s *server) handleEventStream(ctx context.Context, rw http.ResponseWriter, } func (s *server) handleHTTP(ctx context.Context, rw http.ResponseWriter, r *http.Request) error { - accept := r.Header.Get("Accept") - if r.ProtoMajor == 2 && strings.HasPrefix(accept, "application/grpc") { + if r.ProtoMajor == 2 && strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") { s.grpcServer.ServeHTTP(rw, r) return nil } - switch accept { + + switch r.Header.Get("Accept") { case "application/json": return s.handleJSON(ctx, rw, r) case "text/event-stream": @@ -451,9 +522,14 @@ func main() { } defer pcpool.Close() + gs := grpc.NewServer() + webapi.RegisterTrainsServiceServer(gs, &grpcServer{ + dbPool: pcpool, + }) + s := &server{ dbPool: pcpool, - grpcServer: nil, + grpcServer: gs, } listen := ":13974" httpServer := &http.Server{ diff --git a/go/trains/cmd/default.nix b/go/trains/cmd/default.nix index 0cabb77c39..059e6ea2a6 100644 --- a/go/trains/cmd/default.nix +++ b/go/trains/cmd/default.nix @@ -11,6 +11,8 @@ rec { rttingest = import ./rttingest args; train2livesplit = import ./train2livesplit args; + grpctest = import ./grpctest args; + bins = { inherit darwiningestd db2web db2fcm train2livesplit; inherit (rttingest) tiploc; diff --git a/go/trains/webapi/summarize/service.go b/go/trains/webapi/summarize/service.go index 148660b4e1..7caf450dee 100644 --- a/go/trains/webapi/summarize/service.go +++ b/go/trains/webapi/summarize/service.go @@ -3,6 +3,7 @@ package summarize import ( "context" "fmt" + "strings" "github.com/jackc/pgx/v4" "hg.lukegb.com/lukegb/depot/go/trains/webapi" @@ -13,8 +14,8 @@ type Querier interface { QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row } -func Service(ctx context.Context, pc Querier, id int) (*webapi.ServiceData, error) { - row := pc.QueryRow(ctx, ` +func serviceWhere(ctx context.Context, pc Querier, limit int, where string, params ...interface{}) ([]*webapi.ServiceData, error) { + rows, err := pc.Query(ctx, ` SELECT ts.id, ts.rid, ts.uid, ts.rsid, ts.headcode, ts.scheduled_start_date::varchar, ts.train_operator, COALESCE(rop.name, ''), COALESCE(rop.url, ''), @@ -33,29 +34,52 @@ FROM LEFT JOIN ref_locations rcloc ON rcloc.tiploc=ts.cancellation_reason_tiploc LEFT JOIN ref_tocs rcop ON rcop.toc=rcloc.toc WHERE - ts.id=$1 -`, id) - - sd := webapi.ServiceData{ - DelayReason: &webapi.DisruptionReason{Location: &webapi.Location{Operator: &webapi.TrainOperator{}}}, - CancelReason: &webapi.DisruptionReason{Location: &webapi.Location{Operator: &webapi.TrainOperator{}}}, - Operator: &webapi.TrainOperator{}, + `+where+fmt.Sprintf(` +ORDER BY ts.id +LIMIT %d`, limit), params...) + if err != nil { + return nil, fmt.Errorf("querying for services: %w", err) } - if err := row.Scan( - &sd.Id, &sd.Rid, &sd.Uid, &sd.Rsid, &sd.Headcode, &sd.ScheduledStartDate, - &sd.Operator.Code, &sd.Operator.Name, &sd.Operator.Url, - &sd.DelayReason.Code, &sd.DelayReason.Text, &sd.DelayReason.Location.Tiploc, &sd.DelayReason.Location.Name, &sd.DelayReason.Location.Crs, &sd.DelayReason.Location.Operator.Code, &sd.DelayReason.Location.Operator.Name, &sd.DelayReason.Location.Operator.Url, &sd.DelayReason.NearLocation, - &sd.CancelReason.Code, &sd.CancelReason.Text, &sd.CancelReason.Location.Tiploc, &sd.CancelReason.Location.Name, &sd.CancelReason.Location.Crs, &sd.CancelReason.Location.Operator.Code, &sd.CancelReason.Location.Operator.Name, &sd.CancelReason.Location.Operator.Url, &sd.CancelReason.NearLocation, - &sd.IsActive, &sd.IsDeleted, &sd.IsCancelled, - ); err != nil { - return nil, fmt.Errorf("reading from train_services for %d: %w", id, err) - } - sd.Canonicalize() + defer rows.Close() + var sds []*webapi.ServiceData + sdMap := make(map[uint64]*webapi.ServiceData) + for rows.Next() { + sd := &webapi.ServiceData{ + DelayReason: &webapi.DisruptionReason{Location: &webapi.Location{Operator: &webapi.TrainOperator{}}}, + CancelReason: &webapi.DisruptionReason{Location: &webapi.Location{Operator: &webapi.TrainOperator{}}}, + Operator: &webapi.TrainOperator{}, + } + if err := rows.Scan( + &sd.Id, &sd.Rid, &sd.Uid, &sd.Rsid, &sd.Headcode, &sd.ScheduledStartDate, + &sd.Operator.Code, &sd.Operator.Name, &sd.Operator.Url, + &sd.DelayReason.Code, &sd.DelayReason.Text, &sd.DelayReason.Location.Tiploc, &sd.DelayReason.Location.Name, &sd.DelayReason.Location.Crs, &sd.DelayReason.Location.Operator.Code, &sd.DelayReason.Location.Operator.Name, &sd.DelayReason.Location.Operator.Url, &sd.DelayReason.NearLocation, + &sd.CancelReason.Code, &sd.CancelReason.Text, &sd.CancelReason.Location.Tiploc, &sd.CancelReason.Location.Name, &sd.CancelReason.Location.Crs, &sd.CancelReason.Location.Operator.Code, &sd.CancelReason.Location.Operator.Name, &sd.CancelReason.Location.Operator.Url, &sd.CancelReason.NearLocation, + &sd.IsActive, &sd.IsDeleted, &sd.IsCancelled, + ); err != nil { + return nil, fmt.Errorf("reading from train_services: %w", err) + } + sd.Canonicalize() + sds = append(sds, sd) + sdMap[sd.Id] = sd + } + if rows.Err() != nil { + return nil, fmt.Errorf("reading rows for services: %w", err) + } + rows.Close() + + if len(sds) == 0 { + return nil, nil + } + + var idsStrs []string + for id := range sdMap { + idsStrs = append(idsStrs, fmt.Sprintf("%d", id)) + } // Now for the locations. - rows, err := pc.Query(ctx, ` + rows, err = pc.Query(ctx, fmt.Sprintf(` SELECT - tl.id, + tl.tsid, tl.id, tl.tiploc, COALESCE(rloc.name, ''), COALESCE(rloc.crs, ''), COALESCE(rloc.toc, ''), COALESCE(rloctoc.name, ''), COALESCE(rloctoc.url, ''), tl.calling_point::varchar, COALESCE(tl.train_length, 0), tl.service_suppressed, tl.schedule_cancelled, @@ -73,15 +97,16 @@ FROM LEFT JOIN ref_locations rfdloc ON rfdloc.tiploc=tl.schedule_false_destination_tiploc LEFT JOIN ref_tocs rfdloctoc ON rfdloctoc.toc=rfdloc.toc WHERE - tl.tsid=$1 AND tl.active_in_schedule + tl.tsid IN (%s) AND tl.active_in_schedule ORDER BY - COALESCE(tl.schedule_working_arrival, tl.schedule_working_pass, tl.schedule_working_departure) -`, id) + tl.tsid, COALESCE(tl.schedule_working_arrival, tl.schedule_working_pass, tl.schedule_working_departure) +`, strings.Join(idsStrs, ","))) if err != nil { - return nil, fmt.Errorf("querying locations for %d: %w", id, err) + return nil, fmt.Errorf("querying locations: %w", err) } defer rows.Close() for rows.Next() { + var tsid uint64 loc := webapi.ServiceLocation{ Location: &webapi.Location{Operator: &webapi.TrainOperator{}}, Platform: &webapi.PlatformData{}, @@ -91,7 +116,7 @@ ORDER BY PassTiming: &webapi.TimingData{PublicScheduled: &webapi.DateTime{}, WorkingScheduled: &webapi.DateTime{}, PublicEstimated: &webapi.DateTime{}, WorkingEstimated: &webapi.DateTime{}, Actual: &webapi.DateTime{}}, } if err := rows.Scan( - &loc.Id, + &tsid, &loc.Id, &loc.Location.Tiploc, &loc.Location.Name, &loc.Location.Crs, &loc.Location.Operator.Code, &loc.Location.Operator.Name, &loc.Location.Operator.Url, &loc.CallingPointType, &loc.TrainLength, &loc.ServiceSuppressed, &loc.Cancelled, @@ -101,14 +126,105 @@ ORDER BY loc.DepartureTiming.PublicScheduled, loc.DepartureTiming.WorkingScheduled, loc.DepartureTiming.PublicEstimated, loc.DepartureTiming.WorkingEstimated, loc.DepartureTiming.Actual, /*loc.PassTiming.PublicScheduled,*/ loc.PassTiming.WorkingScheduled, loc.PassTiming.PublicEstimated, loc.PassTiming.WorkingEstimated, loc.PassTiming.Actual, ); err != nil { - return nil, fmt.Errorf("scanning locations for %d: %w", id, err) + return nil, fmt.Errorf("scanning locations: %w", err) } loc.Canonicalize() + sd := sdMap[tsid] sd.Locations = append(sd.Locations, &loc) } if rows.Err() != nil { - return nil, fmt.Errorf("iterating over locations for %d: %w", id, err) + return nil, fmt.Errorf("iterating over locations: %w", rows.Err()) } - return &sd, nil + return sds, nil +} + +func Service(ctx context.Context, pc Querier, id int) (*webapi.ServiceData, error) { + sds, err := serviceWhere(ctx, pc, 2, "ts.id=$1", id) + if err != nil { + return nil, fmt.Errorf("query for service %d: %w", id, err) + } + + if len(sds) != 1 { + return nil, fmt.Errorf("query for service %d returned %d services: %w", id, len(sds), err) + } + return sds[0], nil +} + +func ListServices(ctx context.Context, pc Querier, req *webapi.ListServicesRequest) ([]*webapi.ServiceData, error) { + var wheres []string + var vars []interface{} + addVar := func(x interface{}) string { + vars = append(vars, x) + return fmt.Sprintf("$%d", len(vars)) + } + addVars := func(xs ...interface{}) string { + xvars := make([]string, len(xs)) + for n, x := range xs { + xvars[n] = addVar(x) + } + return strings.Join(xvars, ",") + } + + if req.GetRid() != "" { + wheres = append(wheres, fmt.Sprintf("ts.rid=%s", addVar(req.GetRid()))) + } + if req.GetUid() != "" { + wheres = append(wheres, fmt.Sprintf("ts.uid=%s", addVar(req.GetUid()))) + } + if req.GetHeadcode() != "" { + wheres = append(wheres, fmt.Sprintf("ts.headcode=%s", addVar(req.GetHeadcode()))) + } + if req.GetScheduledStartDate() != "" { + wheres = append(wheres, fmt.Sprintf("ts.scheduled_start_date::varchar=%s", addVar(req.GetScheduledStartDate()))) + } + if req.GetTrainOperatorCode() != "" { + wheres = append(wheres, fmt.Sprintf("ts.train_operator=%s", addVar(req.GetTrainOperatorCode()))) + } + + addLocation := func(lr *webapi.LocationRequest, passingPointType ...interface{}) error { + if lr == nil { + return nil + } + var subwheres []string + var specified bool + if lr.GetTiploc() != "" { + subwheres = append(subwheres, fmt.Sprintf("tl.tiploc=%s", addVar(lr.GetTiploc()))) + specified = true + } + if lr.GetCrs() != "" { + subwheres = append(subwheres, fmt.Sprintf("rl.crs=%s AND rl.crs IS NOT NULL", addVar(lr.GetCrs()))) + specified = true + } + if !specified { + return fmt.Errorf("must specify either tiploc or crs") + } + if !lr.GetIncludeInPast() { + subwheres = append(subwheres, "COALESCE(tl.actual_departure, tl.actual_pass, tl.estimated_departure, tl.estimated_pass, tl.working_estimated_departure, tl.working_estimated_pass) > CURRENT_TIMESTAMP") + } + wheres = append(wheres, fmt.Sprintf("ts.id IN (SELECT tl.tsid FROM train_locations tl LEFT JOIN ref_locations rl ON rl.tiploc=tl.tiploc WHERE tl.calling_point IN (%s) AND %s)", addVars(passingPointType...), strings.Join(subwheres, " AND "))) + return nil + } + if err := addLocation(req.GetOrigin(), "OR"); err != nil { + return nil, fmt.Errorf("origin: %w", err) + } + if err := addLocation(req.GetDestination(), "DT"); err != nil { + return nil, fmt.Errorf("destination: %w", err) + } + if err := addLocation(req.GetCallingAt(), "OR", "DT", "IP"); err != nil { + return nil, fmt.Errorf("calling_at: %w", err) + } + if err := addLocation(req.GetPassing(), "OR", "DT", "IP", "PP", "OPIP", "OPOR", "OPDT"); err != nil { + return nil, fmt.Errorf("passing: %w", err) + } + + if !req.GetIncludeDeleted() { + wheres = append(wheres, "NOT ts.deleted") + } + limit := req.GetMaxResults() + if limit == 0 || limit > 100 { + limit = 100 + } + + return serviceWhere(ctx, pc, int(limit), strings.Join(wheres, " AND "), vars...) } diff --git a/go/trains/webapi/webapi.proto b/go/trains/webapi/webapi.proto index d593236ab0..85d5cbd4bc 100644 --- a/go/trains/webapi/webapi.proto +++ b/go/trains/webapi/webapi.proto @@ -36,8 +36,17 @@ message ServiceList { repeated ServiceData services = 1; } +message ListLocationsRequest { + bool include_without_crs = 1; +} + +message LocationList { + repeated Location locations = 1; +} + service TrainsService { rpc SummarizeService(SummarizeServiceRequest) returns (ServiceData) {} rpc ListServices(ListServicesRequest) returns (ServiceList) {} + rpc ListLocations(ListLocationsRequest) returns (LocationList) {} } diff --git a/ops/nixos/totoro/default.nix b/ops/nixos/totoro/default.nix index 15453a5082..8d01683f36 100644 --- a/ops/nixos/totoro/default.nix +++ b/ops/nixos/totoro/default.nix @@ -102,7 +102,7 @@ in { }; users.users.lukegb = { packages = with depot.pkgs; [ irssi ]; - extraGroups = lib.mkAfter [ "libvirtd" ]; + extraGroups = lib.mkAfter [ "libvirtd" "acme" ]; }; users.users.pancake = { isSystemUser = true; @@ -210,6 +210,11 @@ in { systemctl reload nginx ''; }; + certs."trains.lukegb.com" = { + domain = "trains.lukegb.com"; + dnsProvider = "cloudflare"; + credentialsFile = secrets.cloudflareCredentials; + }; }; services.prometheus = {