// SPDX-FileCopyrightText: 2020 Luke Granger-Brown // // SPDX-License-Identifier: Apache-2.0 package main import ( "context" "encoding/json" "flag" "fmt" "log" "net/http" "net/url" "path" "sort" "strconv" "strings" "time" "github.com/google/safehtml" "github.com/google/safehtml/template" "github.com/google/safehtml/uncheckedconversions" "github.com/gorilla/mux" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4/pgxpool" pomerium "github.com/pomerium/sdk-go" "gocloud.dev/blob" "gocloud.dev/gcerrors" _ "gocloud.dev/blob/s3blob" ) var ( listen = flag.String("listen", ":8080", "Listen address") databaseURL = flag.String("database_url", "", "Database URL") userMapping = userMap{} localDisableAuth = flag.String("local_auth_override_user", "", "Disable authn/authz - used for dev - set to username") mediaURL = flag.String("media_url", "s3://public-lukegb-twitterchiver?endpoint=objdump.zxcvbnm.ninja®ion=london", "Blob URL") funcMap = template.FuncMap{ "rewriteURL": rewriteURL, "bestVariant": bestVariant, } indexTmpl = template.Must(template.New("index.html").Funcs(funcMap).ParseFiles("templates/index.html")) tweetsTmpl = template.Must(template.New("tweets.html").Funcs(funcMap).ParseFiles("templates/tweets.html")) ) const ( redirectExpiry = 5 * time.Minute ) func init() { flag.Var(userMapping, "user_to_twitter", "Space-separated list of :") } type userMap map[string][]string func (um userMap) String() string { var bits []string for u, ts := range um { ts2 := make([]string, len(ts)) copy(ts2, ts) sort.Strings(ts2) bits = append(bits, fmt.Sprintf("%s:%s", u, strings.Join(ts2, ","))) } sort.Strings(bits) return strings.Join(bits, " ") } func (um userMap) Set(v string) error { bits := strings.Split(v, " ") for _, b := range bits { utsPair := strings.Split(b, ":") if len(utsPair) != 2 { return fmt.Errorf("%v is not a : string", b) } u := utsPair[0] ts := utsPair[1] um[u] = append(um[u], strings.Split(ts, ",")...) } return nil } func rewriteURL(u string) string { p, err := url.Parse(u) if err != nil { return "" } ps := strings.TrimLeft(p.Path, "/") return path.Join("/media", p.Host, ps) } func bestVariant(vs []interface{}) interface{} { var bestBitrate float64 = -1 var bestVariant interface{} for _, v := range vs { v := v.(map[string]interface{}) if v["content_type"].(string) == "application/x-mpegURL" { // m3u8 isn't very interesting since we'd have to parse it to get a playlist. continue } br := v["bitrate"].(float64) if br > bestBitrate { bestBitrate = br bestVariant = v } } return bestVariant } type mapDatastore map[interface{}]interface{} func (ds mapDatastore) Get(key interface{}) (value interface{}, ok bool) { value, ok = ds[key] return value, ok } func (ds mapDatastore) Add(key, value interface{}) { ds[key] = value } func main() { flag.Parse() ctx := context.Background() mediaBucket, err := blob.OpenBucket(ctx, *mediaURL) if err != nil { log.Fatalf("blob.OpenBucket: %v", err) } pool, err := pgxpool.Connect(ctx, *databaseURL) if err != nil { log.Fatalf("pgxpool.Connect: %v", err) } defer pool.Close() r := mux.NewRouter() r.HandleFunc("/healthz", func(rw http.ResponseWriter, r *http.Request) { rw.Header().Set("Content-Type", "text/plain") fmt.Fprintf(rw, "ok") }) var authR *mux.Router var v *pomerium.Verifier if *localDisableAuth == "" { authR = r.PathPrefix("/").Subrouter() var err error v, err = pomerium.New(&pomerium.Options{ JWKSEndpoint: "https://auth.int.lukegb.com/.well-known/pomerium/jwks.json", Datastore: mapDatastore{}, }) if err != nil { log.Fatalf("pomerium.New: %v", err) } authR.Use(pomerium.AddIdentityToRequest(v)) log.Println("configured pomerium middleware") } else { authR = r } userFromContext := func(ctx context.Context) (string, error) { if *localDisableAuth != "" { return *localDisableAuth, nil } identity, err := pomerium.FromContext(ctx) if err != nil { return "", err } return identity.Email, nil } writeError := func(rw http.ResponseWriter, status int, wrap string, err error) { log.Printf("Error in HTTP handler: %v: %v", wrap, err) rw.WriteHeader(status) fmt.Fprintf(rw, "

Oops. Something went wrong.

") fmt.Fprintf(rw, "

%s

", wrap) } authR.PathPrefix("/media/{filepath:.*}").HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() vars := mux.Vars(r) _, err := userFromContext(ctx) if err != nil { http.Error(rw, fmt.Errorf("userFromContext: %w", err).Error(), http.StatusInternalServerError) return } u, err := mediaBucket.SignedURL(ctx, vars["filepath"], &blob.SignedURLOptions{ Expiry: redirectExpiry, }) switch gcerrors.Code(err) { case gcerrors.NotFound: // This is unlikely to get returned. http.Error(rw, "not found", http.StatusNotFound) return case gcerrors.OK: http.Redirect(rw, r, u, http.StatusFound) return default: log.Printf("media SignedURL: %v", vars["filepath"], err) } http.Error(rw, err.Error(), http.StatusInternalServerError) }) authR.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() user, err := userFromContext(ctx) if err != nil { http.Error(rw, fmt.Errorf("userFromContext: %w", err).Error(), http.StatusInternalServerError) return } twitterAccounts := userMapping[user] var allTweets, timelineTweets, unfetchedMedia, unfetchedRelated int if err := pool.QueryRow(ctx, ` SELECT (SELECT COUNT(id) FROM tweets) all_tweets, (SELECT COUNT(DISTINCT tweetid) FROM user_accounts_tweets WHERE on_timeline) timeline_tweets, (SELECT COUNT(id) FROM tweets WHERE NOT fetched_media) not_fetched_media, (SELECT COUNT(id) FROM tweets WHERE NOT fetched_related_tweets) not_fetched_related `).Scan(&allTweets, &timelineTweets, &unfetchedMedia, &unfetchedRelated); err != nil { writeError(rw, http.StatusInternalServerError, "querying stats", err) return } rows, err := pool.Query(ctx, ` SELECT ua.username, COUNT(uat.tweetid) tweet_count, (SELECT CAST(object->>'created_at' AS timestamp with time zone) FROM tweets WHERE id=MAX(uat.tweetid)) latest_tweet FROM user_accounts ua LEFT JOIN user_accounts_tweets uat ON uat.userid=ua.userid WHERE 1=1 AND ua.username = ANY($1::text[]) AND uat.on_timeline GROUP BY 1 ORDER BY 1 `, twitterAccounts) if err != nil { writeError(rw, http.StatusInternalServerError, "querying database", err) return } defer rows.Close() type twitterData struct { Username string TweetCount int LatestTweet time.Time } var tds []twitterData for rows.Next() { var td twitterData if err := rows.Scan(&td.Username, &td.TweetCount, &td.LatestTweet); err != nil { writeError(rw, http.StatusInternalServerError, "reading from database", err) return } tds = append(tds, td) } rows.Close() indexTmpl.Execute(rw, struct { Username string TwitterAccounts []twitterData TotalTweets int TimelineTweets int UnfetchedMedia int UnfetchedRelated int }{ Username: user, TwitterAccounts: tds, TotalTweets: allTweets, TimelineTweets: timelineTweets, UnfetchedMedia: unfetchedMedia, UnfetchedRelated: unfetchedRelated, }) }) isAllowedToSee := func(ctx context.Context, twitterUser string) bool { user, err := userFromContext(ctx) if err != nil { return false } twitterAccounts := userMapping[user] for _, a := range twitterAccounts { if a == twitterUser { return true } } return false } toInt := func(s string, def int) int { n, err := strconv.ParseInt(s, 10, 0) if err != nil { return def } return int(n) } clamp := func(min, n, max int) int { if n < min { return min } else if n > max { return max } return n } authR.HandleFunc("/view/{twitterUser}", func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() vars := mux.Vars(r) twitterUser := vars["twitterUser"] user, err := userFromContext(ctx) if err != nil { http.Error(rw, fmt.Errorf("userFromContext: %w", err).Error(), http.StatusInternalServerError) return } if !isAllowedToSee(ctx, twitterUser) { writeError(rw, http.StatusNotFound, "no such twitter user being archived", fmt.Errorf("user %q attempted to access %q", user, twitterUser)) return } q := r.URL.Query() pageSize := clamp(1, toInt(q.Get("page_size"), 20), 200) startFrom := toInt(q.Get("start_from"), 0) query := q.Get("q") rows, err := pool.Query(ctx, ` SELECT t.id, t.text, t.object, rt.object retweet_object, qt.object quote_object, CAST(COALESCE(t.object->'retweeted_status'->>'created_at', t.object->>'created_at') AS timestamp with time zone) created_at, CAST(qt.object->>'created_at' AS timestamp with time zone) quoted_created_at FROM user_accounts_tweets uat INNER JOIN user_accounts ua ON ua.userid=uat.userid INNER JOIN tweets t ON t.id=uat.tweetid LEFT JOIN tweets rt ON rt.id=CAST(t.object->'retweeted_status'->>'id' AS BIGINT) LEFT JOIN tweets qt ON qt.id=CAST(t.object->>'quoted_status_id' AS BIGINT) WHERE 1=1 AND ua.username=$1 AND uat.on_timeline AND ($3::bigint=0 OR t.id <= $3::bigint) AND ($4='' OR ($4<>'' AND (to_tsvector('english', t.text) @@ to_tsquery('english', $4) OR to_tsvector('english', t.object->'retweeted_status'->>'full_text') @@ to_tsquery('english', $4) OR t.object->'user'->>'screen_name'=$4))) ORDER BY t.id DESC LIMIT $2 `, twitterUser, pageSize+1, startFrom, query) if err != nil { writeError(rw, http.StatusInternalServerError, "querying database", err) return } defer rows.Close() type tweet struct { ID int64 Text string CreatedAt time.Time CreatedAtFriendly string Object interface{} RetweetedObject interface{} QuotedObject interface{} QuotedCreatedAt time.Time QuotedCreatedAtFriendly string } type twitterData struct { TwitterUsername string Query string Tweets []tweet NextTweetID *int64 FormatTweetText func(string, interface{}) safehtml.HTML } pullIndices := func(m map[string]interface{}) [2]int { midx := m["indices"].([]interface{}) midx0 := int(midx[0].(float64)) midx1 := int(midx[1].(float64)) return [2]int{midx0, midx1} } td := twitterData{ TwitterUsername: twitterUser, Query: query, FormatTweetText: func(t string, tw interface{}) safehtml.HTML { ltRep := string([]rune{0xe000}) gtRep := string([]rune{0xe001}) t = strings.ReplaceAll(t, "<", ltRep) t = strings.ReplaceAll(t, ">", gtRep) type span struct { span [2]int whatDo string // remove, link linkTo string // link only linkText string // link only } // Delete native media and add links. var spans []span obj := tw.(map[string]interface{}) if ee, ok := obj["extended_entities"].(map[string]interface{}); ok { ems := ee["media"].([]interface{}) for _, emi := range ems { em := emi.(map[string]interface{}) spans = append(spans, span{ span: pullIndices(em), whatDo: "remove", }) } } if es, ok := obj["entities"].(map[string]interface{}); ok { if hts, ok := es["hashtags"].([]interface{}); ok { for _, hti := range hts { ht := hti.(map[string]interface{}) htt := ht["text"].(string) spans = append(spans, span{ span: pullIndices(ht), whatDo: "link", linkTo: fmt.Sprintf("https://twitter.com/hashtag/%s", htt), }) } } if urls, ok := es["urls"].([]interface{}); ok { for _, urli := range urls { url := urli.(map[string]interface{}) urldisp := url["display_url"].(string) urlexp := url["expanded_url"].(string) spans = append(spans, span{ span: pullIndices(url), whatDo: "link", linkTo: urlexp, linkText: urldisp, }) } } if mentions, ok := es["user_mentions"].([]interface{}); ok { for _, mentioni := range mentions { mention := mentioni.(map[string]interface{}) mentionuser := mention["screen_name"].(string) spans = append(spans, span{ span: pullIndices(mention), whatDo: "link", linkTo: fmt.Sprintf("https://twitter.com/%s", mentionuser), }) } } if symbols, ok := es["symbols"].([]interface{}); ok { for _, symboli := range symbols { symbol := symboli.(map[string]interface{}) symbolname := symbol["text"].(string) spans = append(spans, span{ span: pullIndices(symbol), whatDo: "link", linkTo: fmt.Sprintf("?q=$%s", symbolname), }) } } } // Sort removeSpans from the end to the beginning. sort.Slice(spans, func(a, b int) bool { return spans[a].span[0] > spans[b].span[0] }) // Expand overlapping remove spans. newSpans := make([]span, 0, len(spans)) for i := 0; i < len(spans)-1; i++ { span := spans[i] prevSpan := spans[i+1] if prevSpan.span[0] <= span.span[0] && prevSpan.span[1] >= span.span[1] { // Spans overlap. if span.whatDo != "remove" || prevSpan.whatDo != "remove" { log.Printf("found overlapping non-remove spans!") } if span.span[1] > prevSpan.span[1] { prevSpan.span[1] = span.span[1] } continue } newSpans = append(newSpans, span) } if len(spans) > 0 { newSpans = append(newSpans, spans[len(spans)-1]) } spans = newSpans runed := []rune(t) for _, span := range spans { switch span.whatDo { case "remove": // Delete text from span[0] to span[1]. runed = append(runed[:span.span[0]], runed[span.span[1]:]...) case "link": // Add a link. var text []rune if span.linkText == "" { text = runed[span.span[0]:span.span[1]] } else { text = []rune(span.linkText) } runedBits := [][]rune{ runed[:span.span[0]], []rune(""), []rune(text), []rune(""), runed[span.span[1]:], } finalLen := 0 for _, s := range runedBits { finalLen += len(s) } runed = make([]rune, finalLen) p := 0 for _, s := range runedBits { p += copy(runed[p:], s) } default: log.Printf("unknown span operation %v", span.whatDo) } } t = string(runed) // HTML escape any <> t = strings.ReplaceAll(t, ltRep, "<") t = strings.ReplaceAll(t, gtRep, ">") return uncheckedconversions.HTMLFromStringKnownToSatisfyTypeContract(t) }, } now := time.Now() for rows.Next() { var t tweet var o, ro, qo pgtype.JSONB var qca pgtype.Timestamptz if err := rows.Scan(&t.ID, &t.Text, &o, &ro, &qo, &t.CreatedAt, &qca); err != nil { writeError(rw, http.StatusInternalServerError, "reading from database", err) return } if o.Status == pgtype.Present { var oo interface{} if err := json.Unmarshal(o.Bytes, &oo); err != nil { writeError(rw, http.StatusInternalServerError, "parsing JSON from database (object)", err) return } t.Object = oo t.CreatedAtFriendly = agoFriendly(now, t.CreatedAt) } if ro.Status == pgtype.Present { var oo interface{} if err := json.Unmarshal(ro.Bytes, &oo); err != nil { writeError(rw, http.StatusInternalServerError, "parsing JSON from database (retweeted object)", err) return } t.RetweetedObject = oo } if qo.Status == pgtype.Present { var oo interface{} if err := json.Unmarshal(qo.Bytes, &oo); err != nil { writeError(rw, http.StatusInternalServerError, "parsing JSON from database (quoted object)", err) return } t.QuotedObject = oo } if qca.Status == pgtype.Present { t.QuotedCreatedAt = qca.Time t.QuotedCreatedAtFriendly = agoFriendly(now, qca.Time) } td.Tweets = append(td.Tweets, t) } rows.Close() if len(td.Tweets) > pageSize { td.NextTweetID = &td.Tweets[pageSize].ID td.Tweets = td.Tweets[:pageSize] } if err := tweetsTmpl.Execute(rw, td); err != nil { log.Printf("tweets: executing template: %v", err) } }) log.Printf("now listening on %s", *listen) log.Print(http.ListenAndServe(*listen, r)) } func agoFriendly(now, at time.Time) string { ago := now.Sub(at) switch { case at.Year() != now.Year(): return at.Format("Jan 2, 2006") case at.YearDay() != now.YearDay(): return at.Format("Jan 2") case ago.Hours() >= 1.0: return fmt.Sprintf("%dh", int(ago.Hours())) case ago.Minutes() >= 1.0: return fmt.Sprintf("%dm", int(ago.Minutes())) case ago.Seconds() >= 0.0: return fmt.Sprintf("%ds", int(ago.Seconds())) default: return fmt.Sprintf("in %ds", -int(ago.Seconds())) } }