564 lines
16 KiB
Go
564 lines
16 KiB
Go
// SPDX-FileCopyrightText: 2020 Luke Granger-Brown <depot@lukegb.com>
|
|
//
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/safehtml"
|
|
"github.com/google/safehtml/template"
|
|
"github.com/google/safehtml/uncheckedconversions"
|
|
"github.com/gorilla/mux"
|
|
"github.com/jackc/pgtype"
|
|
"github.com/jackc/pgx/v4/pgxpool"
|
|
pomerium "github.com/pomerium/sdk-go"
|
|
)
|
|
|
|
var (
	// databaseURL is the connection string passed to pgxpool.Connect in main.
	databaseURL = flag.String("database_url", "", "Database URL")
	// userMapping maps a viewer's username to the Twitter usernames they may
	// see. It is filled in by the -user_to_twitter flag registered in init.
	userMapping = userMap{}
	// localDisableAuth, when non-empty, is returned as the current user for
	// every request instead of consulting Pomerium (development only).
	localDisableAuth = flag.String("local_auth_override_user", "", "Disable authn/authz - used for dev - set to username")
	// mediaDirectory is the on-disk directory served under the /media/ route.
	mediaDirectory = flag.String("media_dir", "/media", "Archived media directory")

	// funcMap exposes helper functions to the HTML templates below.
	funcMap = template.FuncMap{
		"rewriteURL":  rewriteURL,
		"bestVariant": bestVariant,
	}
	// The templates are parsed at package initialisation; template.Must
	// panics if either file cannot be loaded or parsed.
	indexTmpl  = template.Must(template.New("index.html").Funcs(funcMap).ParseFiles("templates/index.html"))
	tweetsTmpl = template.Must(template.New("tweets.html").Funcs(funcMap).ParseFiles("templates/tweets.html"))
)
|
|
|
|
func init() {
|
|
flag.Var(userMapping, "user_to_twitter", "Space-separated list of <username>:<comma-separated Twitter usernames>")
|
|
}
|
|
|
|
// userMap maps a viewer's username to the Twitter usernames they are allowed
// to view. It implements flag.Value so it can be populated from the command
// line.
type userMap map[string][]string

// String renders the map as space-separated "<username>:<handle,handle,...>"
// pairs. Both the pairs and the handles inside each pair are sorted so the
// output is deterministic despite Go's random map iteration order. The stored
// slices are not modified.
func (um userMap) String() string {
	bits := make([]string, 0, len(um))
	for u, ts := range um {
		// Sort a copy so that String has no side effects on the map.
		ts2 := make([]string, len(ts))
		copy(ts2, ts)
		sort.Strings(ts2)
		bits = append(bits, fmt.Sprintf("%s:%s", u, strings.Join(ts2, ",")))
	}
	sort.Strings(bits)
	return strings.Join(bits, " ")
}

// Set parses v as a whitespace-separated list of
// "<username>:<comma-separated twitter usernames>" pairs and appends the
// handles to any already recorded for that username.
//
// Leading, trailing and repeated whitespace is ignored, and empty handles
// (for example from "a:x,,y") are skipped. A pair that does not contain
// exactly one ':' or that has an empty username yields an error; pairs that
// precede the faulty one have already been applied when the error is returned.
func (um userMap) Set(v string) error {
	// strings.Fields (rather than splitting on a single space) avoids empty
	// tokens that would otherwise be reported as malformed pairs.
	for _, b := range strings.Fields(v) {
		utsPair := strings.Split(b, ":")
		if len(utsPair) != 2 || utsPair[0] == "" {
			return fmt.Errorf("%v is not a <username>:<comma-separated twitter usernames> string", b)
		}
		u := utsPair[0]
		for _, t := range strings.Split(utsPair[1], ",") {
			if t == "" {
				continue
			}
			um[u] = append(um[u], t)
		}
	}
	return nil
}
|
|
|
|
// rewriteURL converts an absolute media URL into the local path under /media
// where the archived copy is served: "/media/<host>/<path>". The result is
// cleaned by path.Join. An unparseable URL yields the empty string.
func rewriteURL(u string) string {
	parsed, err := url.Parse(u)
	if err != nil {
		return ""
	}

	segments := []string{
		"/media",
		parsed.Host,
		// Drop leading slashes so the URL path is joined as a relative part.
		strings.TrimLeft(parsed.Path, "/"),
	}
	return path.Join(segments...)
}
|
|
|
|
// bestVariant picks the entry with the highest "bitrate" from a list of
// decoded JSON video variants (each expected to be a map[string]interface{}).
//
// Variants whose "content_type" is "application/x-mpegURL" are ignored, and
// entries that are not JSON objects are skipped rather than causing a panic.
// A variant without a numeric "bitrate" is treated as having bitrate 0, so it
// is still chosen when nothing better exists. The result is the chosen
// variant map, or nil when no usable variant is found.
func bestVariant(vs []interface{}) interface{} {
	var bestBitrate float64 = -1
	var best interface{}
	for _, raw := range vs {
		v, ok := raw.(map[string]interface{})
		if !ok {
			continue
		}
		if ct, _ := v["content_type"].(string); ct == "application/x-mpegURL" {
			// m3u8 isn't very interesting since we'd have to parse it to get a playlist.
			continue
		}
		// A missing or non-numeric bitrate yields 0, which still beats the
		// initial -1 sentinel.
		br, _ := v["bitrate"].(float64)
		if br > bestBitrate {
			bestBitrate = br
			best = v
		}
	}
	return best
}
|
|
|
|
// mapDatastore is a key/value store backed directly by a Go map. main hands
// an instance to pomerium.Options.Datastore. No locking is performed here.
type mapDatastore map[interface{}]interface{}

// Get looks up key and reports whether it was present.
func (ds mapDatastore) Get(key interface{}) (value interface{}, ok bool) {
	stored, found := ds[key]
	return stored, found
}

// Add stores value under key, replacing any existing entry.
func (ds mapDatastore) Add(key, value interface{}) {
	ds[key] = value
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
ctx := context.Background()
|
|
|
|
pool, err := pgxpool.Connect(ctx, *databaseURL)
|
|
if err != nil {
|
|
log.Fatalf("pgxpool.Connect: %v", err)
|
|
}
|
|
defer pool.Close()
|
|
|
|
r := mux.NewRouter()
|
|
r.HandleFunc("/healthz", func(rw http.ResponseWriter, r *http.Request) {
|
|
rw.Header().Set("Content-Type", "text/plain")
|
|
fmt.Fprintf(rw, "ok")
|
|
})
|
|
|
|
var authR *mux.Router
|
|
var v *pomerium.Verifier
|
|
if *localDisableAuth == "" {
|
|
authR = r.PathPrefix("/").Subrouter()
|
|
var err error
|
|
v, err = pomerium.New(&pomerium.Options{
|
|
JWKSEndpoint: "https://auth.int.lukegb.com/.well-known/pomerium/jwks.json",
|
|
Datastore: mapDatastore{},
|
|
})
|
|
if err != nil {
|
|
log.Fatalf("pomerium.New: %v", err)
|
|
}
|
|
authR.Use(pomerium.AddIdentityToRequest(v))
|
|
}
|
|
authR = r
|
|
|
|
userFromContext := func(ctx context.Context) (string, error) {
|
|
if *localDisableAuth != "" {
|
|
return *localDisableAuth, nil
|
|
}
|
|
identity, err := pomerium.FromContext(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return identity.User, nil
|
|
}
|
|
|
|
writeError := func(rw http.ResponseWriter, status int, wrap string, err error) {
|
|
log.Printf("Error in HTTP handler: %v: %v", wrap, err)
|
|
rw.WriteHeader(status)
|
|
fmt.Fprintf(rw, "<h1>Oops. Something went wrong.</h1>")
|
|
fmt.Fprintf(rw, "<p>%s</p>", wrap)
|
|
}
|
|
|
|
authR.PathPrefix("/media/").Handler(http.StripPrefix("/media/", http.FileServer(http.Dir(*mediaDirectory))))
|
|
authR.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
user, err := userFromContext(ctx)
|
|
if err != nil {
|
|
http.Error(rw, fmt.Errorf("userFromContext: %w", err).Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
twitterAccounts := userMapping[user]
|
|
|
|
var allTweets, timelineTweets, unfetchedMedia, unfetchedRelated int
|
|
|
|
if err := pool.QueryRow(ctx, `
|
|
SELECT
|
|
(SELECT COUNT(id) FROM tweets) all_tweets,
|
|
(SELECT COUNT(DISTINCT tweetid) FROM user_accounts_tweets WHERE on_timeline) timeline_tweets,
|
|
(SELECT COUNT(id) FROM tweets WHERE NOT fetched_media) not_fetched_media,
|
|
(SELECT COUNT(id) FROM tweets WHERE NOT fetched_related_tweets) not_fetched_related
|
|
`).Scan(&allTweets, &timelineTweets, &unfetchedMedia, &unfetchedRelated); err != nil {
|
|
writeError(rw, http.StatusInternalServerError, "querying stats", err)
|
|
return
|
|
}
|
|
|
|
rows, err := pool.Query(ctx, `
|
|
SELECT
|
|
ua.username,
|
|
COUNT(uat.tweetid) tweet_count,
|
|
(SELECT CAST(object->>'created_at' AS timestamp with time zone) FROM tweets WHERE id=MAX(uat.tweetid)) latest_tweet
|
|
FROM
|
|
user_accounts ua
|
|
LEFT JOIN
|
|
user_accounts_tweets uat ON uat.userid=ua.userid
|
|
WHERE 1=1
|
|
AND ua.username = ANY($1::text[])
|
|
AND uat.on_timeline
|
|
GROUP BY 1
|
|
ORDER BY 1
|
|
`, twitterAccounts)
|
|
if err != nil {
|
|
writeError(rw, http.StatusInternalServerError, "querying database", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
type twitterData struct {
|
|
Username string
|
|
TweetCount int
|
|
LatestTweet time.Time
|
|
}
|
|
var tds []twitterData
|
|
|
|
for rows.Next() {
|
|
var td twitterData
|
|
if err := rows.Scan(&td.Username, &td.TweetCount, &td.LatestTweet); err != nil {
|
|
writeError(rw, http.StatusInternalServerError, "reading from database", err)
|
|
return
|
|
}
|
|
tds = append(tds, td)
|
|
}
|
|
rows.Close()
|
|
|
|
indexTmpl.Execute(rw, struct {
|
|
Username string
|
|
TwitterAccounts []twitterData
|
|
TotalTweets int
|
|
TimelineTweets int
|
|
UnfetchedMedia int
|
|
UnfetchedRelated int
|
|
}{
|
|
Username: user,
|
|
TwitterAccounts: tds,
|
|
TotalTweets: allTweets,
|
|
TimelineTweets: timelineTweets,
|
|
UnfetchedMedia: unfetchedMedia,
|
|
UnfetchedRelated: unfetchedRelated,
|
|
})
|
|
})
|
|
isAllowedToSee := func(ctx context.Context, twitterUser string) bool {
|
|
user, err := userFromContext(ctx)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
twitterAccounts := userMapping[user]
|
|
for _, a := range twitterAccounts {
|
|
if a == twitterUser {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
toInt := func(s string, def int) int {
|
|
n, err := strconv.ParseInt(s, 10, 0)
|
|
if err != nil {
|
|
return def
|
|
}
|
|
return int(n)
|
|
}
|
|
clamp := func(min, n, max int) int {
|
|
if n < min {
|
|
return min
|
|
} else if n > max {
|
|
return max
|
|
}
|
|
return n
|
|
}
|
|
authR.HandleFunc("/view/{twitterUser}", func(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
vars := mux.Vars(r)
|
|
twitterUser := vars["twitterUser"]
|
|
user, err := userFromContext(ctx)
|
|
if err != nil {
|
|
http.Error(rw, fmt.Errorf("userFromContext: %w", err).Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if !isAllowedToSee(ctx, twitterUser) {
|
|
writeError(rw, http.StatusNotFound, "no such twitter user being archived", fmt.Errorf("user %q attempted to access %q", user, twitterUser))
|
|
return
|
|
}
|
|
|
|
q := r.URL.Query()
|
|
pageSize := clamp(1, toInt(q.Get("page_size"), 20), 200)
|
|
startFrom := toInt(q.Get("start_from"), 0)
|
|
query := q.Get("q")
|
|
|
|
rows, err := pool.Query(ctx, `
|
|
SELECT
|
|
t.id,
|
|
t.text,
|
|
t.object,
|
|
rt.object retweet_object,
|
|
qt.object quote_object,
|
|
CAST(COALESCE(t.object->'retweeted_status'->>'created_at', t.object->>'created_at') AS timestamp with time zone) created_at,
|
|
CAST(qt.object->>'created_at' AS timestamp with time zone) quoted_created_at
|
|
FROM
|
|
user_accounts_tweets uat
|
|
INNER JOIN
|
|
user_accounts ua ON ua.userid=uat.userid
|
|
INNER JOIN
|
|
tweets t ON t.id=uat.tweetid
|
|
LEFT JOIN
|
|
tweets rt ON rt.id=CAST(t.object->'retweeted_status'->>'id' AS BIGINT)
|
|
LEFT JOIN
|
|
tweets qt ON qt.id=CAST(t.object->>'quoted_status_id' AS BIGINT)
|
|
WHERE 1=1
|
|
AND ua.username=$1
|
|
AND uat.on_timeline
|
|
AND ($3::bigint=0 OR t.id <= $3::bigint)
|
|
AND ($4='' OR ($4<>'' AND (to_tsvector('english', t.text) @@ to_tsquery('english', $4) OR to_tsvector('english', t.object->'retweeted_status'->>'full_text') @@ to_tsquery('english', $4) OR t.object->'user'->>'screen_name'=$4)))
|
|
ORDER BY t.id DESC
|
|
LIMIT $2
|
|
`, twitterUser, pageSize+1, startFrom, query)
|
|
if err != nil {
|
|
writeError(rw, http.StatusInternalServerError, "querying database", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
type tweet struct {
|
|
ID int64
|
|
Text string
|
|
CreatedAt time.Time
|
|
CreatedAtFriendly string
|
|
Object interface{}
|
|
RetweetedObject interface{}
|
|
QuotedObject interface{}
|
|
QuotedCreatedAt time.Time
|
|
QuotedCreatedAtFriendly string
|
|
}
|
|
type twitterData struct {
|
|
TwitterUsername string
|
|
Query string
|
|
Tweets []tweet
|
|
NextTweetID *int64
|
|
|
|
FormatTweetText func(string, interface{}) safehtml.HTML
|
|
}
|
|
pullIndices := func(m map[string]interface{}) [2]int {
|
|
midx := m["indices"].([]interface{})
|
|
midx0 := int(midx[0].(float64))
|
|
midx1 := int(midx[1].(float64))
|
|
return [2]int{midx0, midx1}
|
|
}
|
|
td := twitterData{
|
|
TwitterUsername: twitterUser,
|
|
Query: query,
|
|
FormatTweetText: func(t string, tw interface{}) safehtml.HTML {
|
|
ltRep := string([]rune{0xe000})
|
|
gtRep := string([]rune{0xe001})
|
|
t = strings.ReplaceAll(t, "<", ltRep)
|
|
t = strings.ReplaceAll(t, ">", gtRep)
|
|
|
|
type span struct {
|
|
span [2]int
|
|
whatDo string // remove, link
|
|
|
|
linkTo string // link only
|
|
linkText string // link only
|
|
}
|
|
// Delete native media and add links.
|
|
var spans []span
|
|
obj := tw.(map[string]interface{})
|
|
if ee, ok := obj["extended_entities"].(map[string]interface{}); ok {
|
|
ems := ee["media"].([]interface{})
|
|
for _, emi := range ems {
|
|
em := emi.(map[string]interface{})
|
|
spans = append(spans, span{
|
|
span: pullIndices(em),
|
|
whatDo: "remove",
|
|
})
|
|
}
|
|
}
|
|
if es, ok := obj["entities"].(map[string]interface{}); ok {
|
|
if hts, ok := es["hashtags"].([]interface{}); ok {
|
|
for _, hti := range hts {
|
|
ht := hti.(map[string]interface{})
|
|
htt := ht["text"].(string)
|
|
spans = append(spans, span{
|
|
span: pullIndices(ht),
|
|
whatDo: "link",
|
|
linkTo: fmt.Sprintf("https://twitter.com/hashtag/%s", htt),
|
|
})
|
|
}
|
|
}
|
|
if urls, ok := es["urls"].([]interface{}); ok {
|
|
for _, urli := range urls {
|
|
url := urli.(map[string]interface{})
|
|
urldisp := url["display_url"].(string)
|
|
urlexp := url["expanded_url"].(string)
|
|
spans = append(spans, span{
|
|
span: pullIndices(url),
|
|
whatDo: "link",
|
|
linkTo: urlexp,
|
|
linkText: urldisp,
|
|
})
|
|
}
|
|
}
|
|
if mentions, ok := es["user_mentions"].([]interface{}); ok {
|
|
for _, mentioni := range mentions {
|
|
mention := mentioni.(map[string]interface{})
|
|
mentionuser := mention["screen_name"].(string)
|
|
spans = append(spans, span{
|
|
span: pullIndices(mention),
|
|
whatDo: "link",
|
|
linkTo: fmt.Sprintf("https://twitter.com/%s", mentionuser),
|
|
})
|
|
}
|
|
}
|
|
if symbols, ok := es["symbols"].([]interface{}); ok {
|
|
for _, symboli := range symbols {
|
|
symbol := symboli.(map[string]interface{})
|
|
symbolname := symbol["text"].(string)
|
|
spans = append(spans, span{
|
|
span: pullIndices(symbol),
|
|
whatDo: "link",
|
|
linkTo: fmt.Sprintf("?q=$%s", symbolname),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
// Sort removeSpans from the end to the beginning.
|
|
sort.Slice(spans, func(a, b int) bool {
|
|
return spans[a].span[0] > spans[b].span[0]
|
|
})
|
|
// Expand overlapping remove spans.
|
|
newSpans := make([]span, 0, len(spans))
|
|
for i := 0; i < len(spans)-1; i++ {
|
|
span := spans[i]
|
|
prevSpan := spans[i+1]
|
|
if prevSpan.span[0] <= span.span[0] && prevSpan.span[1] >= span.span[1] {
|
|
// Spans overlap.
|
|
if span.whatDo != "remove" || prevSpan.whatDo != "remove" {
|
|
log.Printf("found overlapping non-remove spans!")
|
|
}
|
|
if span.span[1] > prevSpan.span[1] {
|
|
prevSpan.span[1] = span.span[1]
|
|
}
|
|
continue
|
|
}
|
|
newSpans = append(newSpans, span)
|
|
}
|
|
if len(spans) > 0 {
|
|
newSpans = append(newSpans, spans[len(spans)-1])
|
|
}
|
|
spans = newSpans
|
|
runed := []rune(t)
|
|
for _, span := range spans {
|
|
switch span.whatDo {
|
|
case "remove":
|
|
// Delete text from span[0] to span[1].
|
|
runed = append(runed[:span.span[0]], runed[span.span[1]:]...)
|
|
case "link":
|
|
// Add a link.
|
|
var text []rune
|
|
if span.linkText == "" {
|
|
text = runed[span.span[0]:span.span[1]]
|
|
} else {
|
|
text = []rune(span.linkText)
|
|
}
|
|
runedBits := [][]rune{
|
|
runed[:span.span[0]],
|
|
[]rune("<a href=\""),
|
|
[]rune(span.linkTo),
|
|
[]rune("\">"),
|
|
[]rune(text),
|
|
[]rune("</a>"),
|
|
runed[span.span[1]:],
|
|
}
|
|
finalLen := 0
|
|
for _, s := range runedBits {
|
|
finalLen += len(s)
|
|
}
|
|
runed = make([]rune, finalLen)
|
|
p := 0
|
|
for _, s := range runedBits {
|
|
p += copy(runed[p:], s)
|
|
}
|
|
default:
|
|
log.Printf("unknown span operation %v", span.whatDo)
|
|
}
|
|
}
|
|
t = string(runed)
|
|
|
|
// HTML escape any <>
|
|
t = strings.ReplaceAll(t, ltRep, "<")
|
|
t = strings.ReplaceAll(t, gtRep, ">")
|
|
|
|
return uncheckedconversions.HTMLFromStringKnownToSatisfyTypeContract(t)
|
|
},
|
|
}
|
|
now := time.Now()
|
|
for rows.Next() {
|
|
var t tweet
|
|
var o, ro, qo pgtype.JSONB
|
|
var qca pgtype.Timestamptz
|
|
if err := rows.Scan(&t.ID, &t.Text, &o, &ro, &qo, &t.CreatedAt, &qca); err != nil {
|
|
writeError(rw, http.StatusInternalServerError, "reading from database", err)
|
|
return
|
|
}
|
|
if o.Status == pgtype.Present {
|
|
var oo interface{}
|
|
if err := json.Unmarshal(o.Bytes, &oo); err != nil {
|
|
writeError(rw, http.StatusInternalServerError, "parsing JSON from database (object)", err)
|
|
return
|
|
}
|
|
t.Object = oo
|
|
t.CreatedAtFriendly = agoFriendly(now, t.CreatedAt)
|
|
}
|
|
if ro.Status == pgtype.Present {
|
|
var oo interface{}
|
|
if err := json.Unmarshal(ro.Bytes, &oo); err != nil {
|
|
writeError(rw, http.StatusInternalServerError, "parsing JSON from database (retweeted object)", err)
|
|
return
|
|
}
|
|
t.RetweetedObject = oo
|
|
}
|
|
if qo.Status == pgtype.Present {
|
|
var oo interface{}
|
|
if err := json.Unmarshal(qo.Bytes, &oo); err != nil {
|
|
writeError(rw, http.StatusInternalServerError, "parsing JSON from database (quoted object)", err)
|
|
return
|
|
}
|
|
t.QuotedObject = oo
|
|
}
|
|
if qca.Status == pgtype.Present {
|
|
t.QuotedCreatedAt = qca.Time
|
|
t.QuotedCreatedAtFriendly = agoFriendly(now, qca.Time)
|
|
}
|
|
td.Tweets = append(td.Tweets, t)
|
|
}
|
|
rows.Close()
|
|
|
|
if len(td.Tweets) > pageSize {
|
|
td.NextTweetID = &td.Tweets[pageSize].ID
|
|
td.Tweets = td.Tweets[:pageSize]
|
|
}
|
|
|
|
if err := tweetsTmpl.Execute(rw, td); err != nil {
|
|
log.Printf("tweets: executing template: %v", err)
|
|
}
|
|
})
|
|
|
|
log.Printf("now listening on :8080")
|
|
log.Print(http.ListenAndServe(":8080", r))
|
|
}
|
|
// agoFriendly renders the instant at relative to now as a short label:
//
//   - "Jan 2, 2006" when at falls in a different calendar year than now
//   - "Jan 2"       when it falls on a different day of the same year
//   - "3h", "5m", "30s" for earlier moments on the same day
//   - "in 10s"      for moments on the same day that lie in the future
//
// Calendar comparisons and date formatting are done in now's location, so an
// at value carrying a different time zone is classified by the day it falls
// on for the viewer rather than by the day in its own zone.
func agoFriendly(now, at time.Time) string {
	at = at.In(now.Location())
	ago := now.Sub(at)
	switch {
	case at.Year() != now.Year():
		return at.Format("Jan 2, 2006")
	case at.YearDay() != now.YearDay():
		return at.Format("Jan 2")
	case ago.Hours() >= 1.0:
		return fmt.Sprintf("%dh", int(ago.Hours()))
	case ago.Minutes() >= 1.0:
		return fmt.Sprintf("%dm", int(ago.Minutes()))
	case ago.Seconds() >= 0.0:
		return fmt.Sprintf("%ds", int(ago.Seconds()))
	default:
		// at is later than now (e.g. clock skew).
		return fmt.Sprintf("in %ds", -int(ago.Seconds()))
	}
}
|