go/trains: go! trains!

This commit is contained in:
Luke Granger-Brown 2021-11-23 12:32:01 +00:00
parent bde764c537
commit 1eda43af34
5 changed files with 242 additions and 34 deletions

View file

@ -18,6 +18,7 @@ import (
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/protobuf/encoding/protojson"
"hg.lukegb.com/lukegb/depot/go/trains/webapi"
"hg.lukegb.com/lukegb/depot/go/trains/webapi/summarize"
)
@ -73,6 +74,76 @@ func (he httpError) Error() string {
return he.err.Error()
}
type grpcServer struct {
webapi.UnimplementedTrainsServiceServer
dbPool *pgxpool.Pool
}
func (s *grpcServer) SummarizeService(ctx context.Context, req *webapi.SummarizeServiceRequest) (*webapi.ServiceData, error) {
conn, err := s.dbPool.Acquire(ctx)
if err != nil {
return nil, fmt.Errorf("dbPool.Acquire: %w", err)
}
defer conn.Release()
return summarize.Service(ctx, conn, int(req.GetId()))
}
func (s *grpcServer) ListServices(ctx context.Context, req *webapi.ListServicesRequest) (*webapi.ServiceList, error) {
conn, err := s.dbPool.Acquire(ctx)
if err != nil {
return nil, fmt.Errorf("dbPool.Acquire: %w", err)
}
defer conn.Release()
sds, err := summarize.ListServices(ctx, conn, req)
if err != nil {
return nil, fmt.Errorf("ListServices: %w", err)
}
return &webapi.ServiceList{Services: sds}, nil
}
func (s *grpcServer) ListLocations(ctx context.Context, req *webapi.ListLocationsRequest) (*webapi.LocationList, error) {
conn, err := s.dbPool.Acquire(ctx)
if err != nil {
return nil, fmt.Errorf("dbPool.Acquire: %w", err)
}
defer conn.Release()
var locs []*webapi.Location
where := "TRUE"
if !req.GetIncludeWithoutCrs() {
where = "rl.crs IS NOT NULL AND rl.crs <> ''"
}
rows, err := conn.Query(ctx, fmt.Sprintf(`
SELECT
rl.tiploc, rl.name, rl.crs,
rltoc.toc, rltoc.name, rltoc.url
FROM
ref_locations rl
INNER JOIN
ref_tocs rltoc ON rltoc.toc = rl.toc
WHERE %s`, where))
if err != nil {
return nil, fmt.Errorf("querying for locations: %w", err)
}
for rows.Next() {
loc := &webapi.Location{Operator: &webapi.TrainOperator{}}
if err := rows.Scan(&loc.Tiploc, &loc.Name, &loc.Crs, &loc.Operator.Code, &loc.Operator.Name, &loc.Operator.Url); err != nil {
return nil, fmt.Errorf("scanning for locations: %w", err)
}
loc.Canonicalize()
locs = append(locs, loc)
}
if rows.Err() != nil {
return nil, fmt.Errorf("rows error: %w", rows.Err())
}
return &webapi.LocationList{Locations: locs}, nil
}
type server struct {
dbPool *pgxpool.Pool
grpcServer *grpc.Server
@ -368,12 +439,12 @@ func (s *server) handleEventStream(ctx context.Context, rw http.ResponseWriter,
}
func (s *server) handleHTTP(ctx context.Context, rw http.ResponseWriter, r *http.Request) error {
accept := r.Header.Get("Accept")
if r.ProtoMajor == 2 && strings.HasPrefix(accept, "application/grpc") {
if r.ProtoMajor == 2 && strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") {
s.grpcServer.ServeHTTP(rw, r)
return nil
}
switch accept {
switch r.Header.Get("Accept") {
case "application/json":
return s.handleJSON(ctx, rw, r)
case "text/event-stream":
@ -451,9 +522,14 @@ func main() {
}
defer pcpool.Close()
gs := grpc.NewServer()
webapi.RegisterTrainsServiceServer(gs, &grpcServer{
dbPool: pcpool,
})
s := &server{
dbPool: pcpool,
grpcServer: nil,
grpcServer: gs,
}
listen := ":13974"
httpServer := &http.Server{

View file

@ -11,6 +11,8 @@ rec {
rttingest = import ./rttingest args;
train2livesplit = import ./train2livesplit args;
grpctest = import ./grpctest args;
bins = {
inherit darwiningestd db2web db2fcm train2livesplit;
inherit (rttingest) tiploc;

View file

@ -3,6 +3,7 @@ package summarize
import (
"context"
"fmt"
"strings"
"github.com/jackc/pgx/v4"
"hg.lukegb.com/lukegb/depot/go/trains/webapi"
@ -13,8 +14,8 @@ type Querier interface {
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
}
func Service(ctx context.Context, pc Querier, id int) (*webapi.ServiceData, error) {
row := pc.QueryRow(ctx, `
func serviceWhere(ctx context.Context, pc Querier, limit int, where string, params ...interface{}) ([]*webapi.ServiceData, error) {
rows, err := pc.Query(ctx, `
SELECT
ts.id, ts.rid, ts.uid, ts.rsid, ts.headcode, ts.scheduled_start_date::varchar,
ts.train_operator, COALESCE(rop.name, ''), COALESCE(rop.url, ''),
@ -33,29 +34,52 @@ FROM
LEFT JOIN ref_locations rcloc ON rcloc.tiploc=ts.cancellation_reason_tiploc
LEFT JOIN ref_tocs rcop ON rcop.toc=rcloc.toc
WHERE
ts.id=$1
`, id)
`+where+fmt.Sprintf(`
ORDER BY ts.id
LIMIT %d`, limit), params...)
if err != nil {
return nil, fmt.Errorf("querying for services: %w", err)
}
defer rows.Close()
var sds []*webapi.ServiceData
sdMap := make(map[uint64]*webapi.ServiceData)
sd := webapi.ServiceData{
for rows.Next() {
sd := &webapi.ServiceData{
DelayReason: &webapi.DisruptionReason{Location: &webapi.Location{Operator: &webapi.TrainOperator{}}},
CancelReason: &webapi.DisruptionReason{Location: &webapi.Location{Operator: &webapi.TrainOperator{}}},
Operator: &webapi.TrainOperator{},
}
if err := row.Scan(
if err := rows.Scan(
&sd.Id, &sd.Rid, &sd.Uid, &sd.Rsid, &sd.Headcode, &sd.ScheduledStartDate,
&sd.Operator.Code, &sd.Operator.Name, &sd.Operator.Url,
&sd.DelayReason.Code, &sd.DelayReason.Text, &sd.DelayReason.Location.Tiploc, &sd.DelayReason.Location.Name, &sd.DelayReason.Location.Crs, &sd.DelayReason.Location.Operator.Code, &sd.DelayReason.Location.Operator.Name, &sd.DelayReason.Location.Operator.Url, &sd.DelayReason.NearLocation,
&sd.CancelReason.Code, &sd.CancelReason.Text, &sd.CancelReason.Location.Tiploc, &sd.CancelReason.Location.Name, &sd.CancelReason.Location.Crs, &sd.CancelReason.Location.Operator.Code, &sd.CancelReason.Location.Operator.Name, &sd.CancelReason.Location.Operator.Url, &sd.CancelReason.NearLocation,
&sd.IsActive, &sd.IsDeleted, &sd.IsCancelled,
); err != nil {
return nil, fmt.Errorf("reading from train_services for %d: %w", id, err)
return nil, fmt.Errorf("reading from train_services: %w", err)
}
sd.Canonicalize()
sds = append(sds, sd)
sdMap[sd.Id] = sd
}
if rows.Err() != nil {
return nil, fmt.Errorf("reading rows for services: %w", err)
}
rows.Close()
if len(sds) == 0 {
return nil, nil
}
var idsStrs []string
for id := range sdMap {
idsStrs = append(idsStrs, fmt.Sprintf("%d", id))
}
// Now for the locations.
rows, err := pc.Query(ctx, `
rows, err = pc.Query(ctx, fmt.Sprintf(`
SELECT
tl.id,
tl.tsid, tl.id,
tl.tiploc, COALESCE(rloc.name, ''), COALESCE(rloc.crs, ''), COALESCE(rloc.toc, ''), COALESCE(rloctoc.name, ''), COALESCE(rloctoc.url, ''),
tl.calling_point::varchar, COALESCE(tl.train_length, 0), tl.service_suppressed,
tl.schedule_cancelled,
@ -73,15 +97,16 @@ FROM
LEFT JOIN ref_locations rfdloc ON rfdloc.tiploc=tl.schedule_false_destination_tiploc
LEFT JOIN ref_tocs rfdloctoc ON rfdloctoc.toc=rfdloc.toc
WHERE
tl.tsid=$1 AND tl.active_in_schedule
tl.tsid IN (%s) AND tl.active_in_schedule
ORDER BY
COALESCE(tl.schedule_working_arrival, tl.schedule_working_pass, tl.schedule_working_departure)
`, id)
tl.tsid, COALESCE(tl.schedule_working_arrival, tl.schedule_working_pass, tl.schedule_working_departure)
`, strings.Join(idsStrs, ",")))
if err != nil {
return nil, fmt.Errorf("querying locations for %d: %w", id, err)
return nil, fmt.Errorf("querying locations: %w", err)
}
defer rows.Close()
for rows.Next() {
var tsid uint64
loc := webapi.ServiceLocation{
Location: &webapi.Location{Operator: &webapi.TrainOperator{}},
Platform: &webapi.PlatformData{},
@ -91,7 +116,7 @@ ORDER BY
PassTiming: &webapi.TimingData{PublicScheduled: &webapi.DateTime{}, WorkingScheduled: &webapi.DateTime{}, PublicEstimated: &webapi.DateTime{}, WorkingEstimated: &webapi.DateTime{}, Actual: &webapi.DateTime{}},
}
if err := rows.Scan(
&loc.Id,
&tsid, &loc.Id,
&loc.Location.Tiploc, &loc.Location.Name, &loc.Location.Crs, &loc.Location.Operator.Code, &loc.Location.Operator.Name, &loc.Location.Operator.Url,
&loc.CallingPointType, &loc.TrainLength, &loc.ServiceSuppressed,
&loc.Cancelled,
@ -101,14 +126,105 @@ ORDER BY
loc.DepartureTiming.PublicScheduled, loc.DepartureTiming.WorkingScheduled, loc.DepartureTiming.PublicEstimated, loc.DepartureTiming.WorkingEstimated, loc.DepartureTiming.Actual,
/*loc.PassTiming.PublicScheduled,*/ loc.PassTiming.WorkingScheduled, loc.PassTiming.PublicEstimated, loc.PassTiming.WorkingEstimated, loc.PassTiming.Actual,
); err != nil {
return nil, fmt.Errorf("scanning locations for %d: %w", id, err)
return nil, fmt.Errorf("scanning locations: %w", err)
}
loc.Canonicalize()
sd := sdMap[tsid]
sd.Locations = append(sd.Locations, &loc)
}
if rows.Err() != nil {
return nil, fmt.Errorf("iterating over locations for %d: %w", id, err)
return nil, fmt.Errorf("iterating over locations: %w", rows.Err())
}
return &sd, nil
return sds, nil
}
func Service(ctx context.Context, pc Querier, id int) (*webapi.ServiceData, error) {
sds, err := serviceWhere(ctx, pc, 2, "ts.id=$1", id)
if err != nil {
return nil, fmt.Errorf("query for service %d: %w", id, err)
}
if len(sds) != 1 {
return nil, fmt.Errorf("query for service %d returned %d services: %w", id, len(sds), err)
}
return sds[0], nil
}
func ListServices(ctx context.Context, pc Querier, req *webapi.ListServicesRequest) ([]*webapi.ServiceData, error) {
var wheres []string
var vars []interface{}
addVar := func(x interface{}) string {
vars = append(vars, x)
return fmt.Sprintf("$%d", len(vars))
}
addVars := func(xs ...interface{}) string {
xvars := make([]string, len(xs))
for n, x := range xs {
xvars[n] = addVar(x)
}
return strings.Join(xvars, ",")
}
if req.GetRid() != "" {
wheres = append(wheres, fmt.Sprintf("ts.rid=%s", addVar(req.GetRid())))
}
if req.GetUid() != "" {
wheres = append(wheres, fmt.Sprintf("ts.uid=%s", addVar(req.GetUid())))
}
if req.GetHeadcode() != "" {
wheres = append(wheres, fmt.Sprintf("ts.headcode=%s", addVar(req.GetHeadcode())))
}
if req.GetScheduledStartDate() != "" {
wheres = append(wheres, fmt.Sprintf("ts.scheduled_start_date::varchar=%s", addVar(req.GetScheduledStartDate())))
}
if req.GetTrainOperatorCode() != "" {
wheres = append(wheres, fmt.Sprintf("ts.train_operator=%s", addVar(req.GetTrainOperatorCode())))
}
addLocation := func(lr *webapi.LocationRequest, passingPointType ...interface{}) error {
if lr == nil {
return nil
}
var subwheres []string
var specified bool
if lr.GetTiploc() != "" {
subwheres = append(subwheres, fmt.Sprintf("tl.tiploc=%s", addVar(lr.GetTiploc())))
specified = true
}
if lr.GetCrs() != "" {
subwheres = append(subwheres, fmt.Sprintf("rl.crs=%s AND rl.crs IS NOT NULL", addVar(lr.GetCrs())))
specified = true
}
if !specified {
return fmt.Errorf("must specify either tiploc or crs")
}
if !lr.GetIncludeInPast() {
subwheres = append(subwheres, "COALESCE(tl.actual_departure, tl.actual_pass, tl.estimated_departure, tl.estimated_pass, tl.working_estimated_departure, tl.working_estimated_pass) > CURRENT_TIMESTAMP")
}
wheres = append(wheres, fmt.Sprintf("ts.id IN (SELECT tl.tsid FROM train_locations tl LEFT JOIN ref_locations rl ON rl.tiploc=tl.tiploc WHERE tl.calling_point IN (%s) AND %s)", addVars(passingPointType...), strings.Join(subwheres, " AND ")))
return nil
}
if err := addLocation(req.GetOrigin(), "OR"); err != nil {
return nil, fmt.Errorf("origin: %w", err)
}
if err := addLocation(req.GetDestination(), "DT"); err != nil {
return nil, fmt.Errorf("destination: %w", err)
}
if err := addLocation(req.GetCallingAt(), "OR", "DT", "IP"); err != nil {
return nil, fmt.Errorf("calling_at: %w", err)
}
if err := addLocation(req.GetPassing(), "OR", "DT", "IP", "PP", "OPIP", "OPOR", "OPDT"); err != nil {
return nil, fmt.Errorf("passing: %w", err)
}
if !req.GetIncludeDeleted() {
wheres = append(wheres, "NOT ts.deleted")
}
limit := req.GetMaxResults()
if limit == 0 || limit > 100 {
limit = 100
}
return serviceWhere(ctx, pc, int(limit), strings.Join(wheres, " AND "), vars...)
}

View file

@ -36,8 +36,17 @@ message ServiceList {
repeated ServiceData services = 1;
}
message ListLocationsRequest {
bool include_without_crs = 1;
}
message LocationList {
repeated Location locations = 1;
}
service TrainsService {
rpc SummarizeService(SummarizeServiceRequest) returns (ServiceData) {}
rpc ListServices(ListServicesRequest) returns (ServiceList) {}
rpc ListLocations(ListLocationsRequest) returns (LocationList) {}
}

View file

@ -102,7 +102,7 @@ in {
};
users.users.lukegb = {
packages = with depot.pkgs; [ irssi ];
extraGroups = lib.mkAfter [ "libvirtd" ];
extraGroups = lib.mkAfter [ "libvirtd" "acme" ];
};
users.users.pancake = {
isSystemUser = true;
@ -210,6 +210,11 @@ in {
systemctl reload nginx
'';
};
certs."trains.lukegb.com" = {
domain = "trains.lukegb.com";
dnsProvider = "cloudflare";
credentialsFile = secrets.cloudflareCredentials;
};
};
services.prometheus = {