// SPDX-FileCopyrightText: 2020 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0
package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"path"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/google/safehtml"
	"github.com/google/safehtml/template"
	"github.com/google/safehtml/uncheckedconversions"
	"github.com/gorilla/mux"
	"github.com/jackc/pgtype"
	"github.com/jackc/pgx/v4/pgxpool"
	pomerium "github.com/pomerium/sdk-go"
	"gocloud.dev/blob"
	"gocloud.dev/gcerrors"

	_ "gocloud.dev/blob/s3blob"
)

// Command-line flags, the user-to-Twitter-account mapping and the parsed HTML
// templates shared by the HTTP handlers in main.
//
// userMapping is registered as the "user_to_twitter" flag in init rather than
// here, because it is a custom flag.Value.
var (
	listen           = flag.String("listen", ":8080", "Listen address")
	databaseURL      = flag.String("database_url", "", "Database URL")
	userMapping      = userMap{}
	localDisableAuth = flag.String("local_auth_override_user", "", "Disable authn/authz - used for dev - set to username")
	mediaURL         = flag.String("media_url", "s3://public-lukegb-twitterchiver?endpoint=objdump.zxcvbnm.ninja&region=london", "Blob URL")

	// funcMap exposes rewriteURL and bestVariant to the templates below under
	// those same names.
	funcMap = template.FuncMap{
		"rewriteURL":  rewriteURL,
		"bestVariant": bestVariant,
	}
	// The templates are parsed during package initialisation from relative
	// paths; template.Must turns a parse failure into a panic at startup.
	indexTmpl  = template.Must(template.New("index.html").Funcs(funcMap).ParseFiles("templates/index.html"))
	tweetsTmpl = template.Must(template.New("tweets.html").Funcs(funcMap).ParseFiles("templates/tweets.html"))
)

const (
	// redirectExpiry is the Expiry passed to the media bucket when the /media
	// handler requests a signed URL to redirect the client to.
	redirectExpiry = 5 * time.Minute
)

// init registers userMapping as the -user_to_twitter flag. It has to go
// through flag.Var because userMap is a custom flag.Value rather than one of
// the built-in flag types.
func init() {
	const usage = "Space-separated list of <username>:<comma-separated Twitter usernames>"
	flag.Var(userMapping, "user_to_twitter", usage)
}

// userMap maps an authenticated username to the Twitter usernames that user is
// allowed to view. It implements flag.Value so that it can be filled in from
// the command line.
type userMap map[string][]string

// String renders the mapping in the syntax accepted by Set: space-separated
// <username>:<comma-separated Twitter usernames> pairs. Both the pairs and the
// Twitter usernames inside each pair are sorted, so the output does not depend
// on map iteration order.
func (um userMap) String() string {
	bits := make([]string, 0, len(um))
	for u, ts := range um {
		// Sort a copy so that String never reorders the stored slice.
		ts2 := make([]string, len(ts))
		copy(ts2, ts)
		sort.Strings(ts2)
		bits = append(bits, fmt.Sprintf("%s:%s", u, strings.Join(ts2, ",")))
	}
	sort.Strings(bits)
	return strings.Join(bits, " ")
}

// Set parses v as whitespace-separated <username>:<comma-separated Twitter
// usernames> pairs and appends the Twitter usernames to each user's entry.
//
// Leading, trailing and repeated whitespace is ignored. Every pair is
// validated before the map is touched, so a malformed value leaves the mapping
// exactly as it was. An error is returned for any pair that does not contain
// exactly one ':'.
func (um userMap) Set(v string) error {
	type pair struct {
		user     string
		accounts []string
	}

	fields := strings.Fields(v)
	pairs := make([]pair, 0, len(fields))
	for _, b := range fields {
		utsPair := strings.Split(b, ":")
		if len(utsPair) != 2 {
			return fmt.Errorf("%v is not a <username>:<comma-separated twitter usernames> string", b)
		}
		pairs = append(pairs, pair{
			user:     utsPair[0],
			accounts: strings.Split(utsPair[1], ","),
		})
	}

	for _, p := range pairs {
		um[p.user] = append(um[p.user], p.accounts...)
	}
	return nil
}

// rewriteURL maps an absolute media URL onto this server's /media route, as
// /media/<host>/<path>. The query string and fragment are dropped, and the
// result is cleaned by path.Join. An unparseable URL yields the empty string.
func rewriteURL(u string) string {
	parsed, err := url.Parse(u)
	if err != nil {
		return ""
	}
	// Strip leading slashes so the original path is joined as a relative
	// component underneath the host.
	return path.Join("/media", parsed.Host, strings.TrimLeft(parsed.Path, "/"))
}

// bestVariant picks the video variant with the highest "bitrate" from vs, a
// list of JSON-decoded variant objects (map[string]interface{}).
//
// Variants whose "content_type" is "application/x-mpegURL" are ignored. Entries
// that are not JSON objects are skipped, and a variant without a numeric
// "bitrate" is treated as bitrate 0, so malformed data no longer panics. The
// result is nil when no usable variant exists; on ties the earliest variant
// wins.
func bestVariant(vs []interface{}) interface{} {
	var bestBitrate float64 = -1
	var best interface{}
	for _, raw := range vs {
		v, ok := raw.(map[string]interface{})
		if !ok {
			continue
		}
		if ct, _ := v["content_type"].(string); ct == "application/x-mpegURL" {
			// m3u8 isn't very interesting since we'd have to parse it to get a playlist.
			continue
		}
		// A missing or non-numeric bitrate leaves br at 0, which still beats
		// the initial -1 so a lone variant is returned.
		br, _ := v["bitrate"].(float64)
		if br > bestBitrate {
			bestBitrate = br
			best = v
		}
	}
	return best
}

// mapDatastore is a minimal in-memory key/value store backed by a plain Go
// map. main passes one to pomerium.Options as its Datastore.
//
// NOTE(review): there is no locking here; confirm that the pomerium verifier
// never calls Get and Add from concurrent requests.
type mapDatastore map[interface{}]interface{}

// Get looks up key, reporting via ok whether an entry exists. A missing key
// yields a nil value.
func (ds mapDatastore) Get(key interface{}) (value interface{}, ok bool) {
	if stored, found := ds[key]; found {
		return stored, true
	}
	return nil, false
}

// Add stores value under key, replacing any existing entry.
func (ds mapDatastore) Add(key, value interface{}) {
	ds[key] = value
}

// main wires up the archive viewer: it parses flags, opens the media bucket
// and the Postgres pool, registers the HTTP handlers and then serves on the
// -listen address until the server fails.
//
// Routes:
//
//	/healthz             liveness probe, registered outside the auth middleware
//	/media/...           redirect to a short-lived signed URL in the media bucket
//	/                    archive statistics and per-account summaries
//	/view/{twitterUser}  paginated, searchable timeline for one account
//
// Unless -local_auth_override_user is set, everything except /healthz sits
// behind the pomerium identity middleware.
func main() {
	flag.Parse()
	ctx := context.Background()

	mediaBucket, err := blob.OpenBucket(ctx, *mediaURL)
	if err != nil {
		log.Fatalf("blob.OpenBucket: %v", err)
	}

	pool, err := pgxpool.Connect(ctx, *databaseURL)
	if err != nil {
		log.Fatalf("pgxpool.Connect: %v", err)
	}
	defer pool.Close()

	r := mux.NewRouter()
	// /healthz lives on the root router, so the middleware attached to the
	// subrouter below never runs for it.
	r.HandleFunc("/healthz", func(rw http.ResponseWriter, r *http.Request) {
		rw.Header().Set("Content-Type", "text/plain")
		fmt.Fprintf(rw, "ok")
	})

	// authR carries every route that needs an identity. With auth disabled it
	// is simply the root router.
	var authR *mux.Router
	var v *pomerium.Verifier
	if *localDisableAuth == "" {
		authR = r.PathPrefix("/").Subrouter()
		var err error
		v, err = pomerium.New(&pomerium.Options{
			JWKSEndpoint: "https://auth.int.lukegb.com/.well-known/pomerium/jwks.json",
			Datastore:    mapDatastore{},
		})
		if err != nil {
			log.Fatalf("pomerium.New: %v", err)
		}
		authR.Use(pomerium.AddIdentityToRequest(v))
		log.Println("configured pomerium middleware")
	} else {
		authR = r
	}

	// userFromContext returns the name used to look up userMapping: the
	// override user when auth is disabled, otherwise the e-mail address of the
	// pomerium identity attached to ctx.
	userFromContext := func(ctx context.Context) (string, error) {
		if *localDisableAuth != "" {
			return *localDisableAuth, nil
		}
		identity, err := pomerium.FromContext(ctx)
		if err != nil {
			return "", err
		}
		return identity.Email, nil
	}

	// writeError logs err together with wrap and sends a generic HTML error
	// page with the given status. Only wrap is shown to the client; err stays
	// in the server log.
	writeError := func(rw http.ResponseWriter, status int, wrap string, err error) {
		log.Printf("Error in HTTP handler: %v: %v", wrap, err)
		rw.WriteHeader(status)
		fmt.Fprintf(rw, "<h1>Oops. Something went wrong.</h1>")
		fmt.Fprintf(rw, "<p>%s</p>", wrap)
	}

	authR.PathPrefix("/media/{filepath:.*}").HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		vars := mux.Vars(r)
		// The user itself is not needed; this only checks that an identity
		// is present.
		_, err := userFromContext(ctx)
		if err != nil {
			http.Error(rw, fmt.Errorf("userFromContext: %w", err).Error(), http.StatusInternalServerError)
			return
		}
		u, err := mediaBucket.SignedURL(ctx, vars["filepath"], &blob.SignedURLOptions{
			Expiry: redirectExpiry,
		})
		switch gcerrors.Code(err) {
		case gcerrors.NotFound:
			// This is unlikely to get returned.
			http.Error(rw, "not found", http.StatusNotFound)
			return
		case gcerrors.OK:
			http.Redirect(rw, r, u, http.StatusFound)
			return
		default:
			log.Printf("media SignedURL(%q): %v", vars["filepath"], err)
		}
		// Only reached from the default case above, where err is non-nil.
		http.Error(rw, err.Error(), http.StatusInternalServerError)
	})
	authR.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		user, err := userFromContext(ctx)
		if err != nil {
			http.Error(rw, fmt.Errorf("userFromContext: %w", err).Error(), http.StatusInternalServerError)
			return
		}
		twitterAccounts := userMapping[user]

		var allTweets, timelineTweets, unfetchedMedia, unfetchedRelated int

		// Archive-wide counters; these are not restricted to the user's accounts.
		if err := pool.QueryRow(ctx, `
SELECT
	(SELECT COUNT(id) FROM tweets) all_tweets,
	(SELECT COUNT(DISTINCT tweetid) FROM user_accounts_tweets WHERE on_timeline) timeline_tweets,
	(SELECT COUNT(id) FROM tweets WHERE NOT fetched_media) not_fetched_media,
	(SELECT COUNT(id) FROM tweets WHERE NOT fetched_related_tweets) not_fetched_related
`).Scan(&allTweets, &timelineTweets, &unfetchedMedia, &unfetchedRelated); err != nil {
			writeError(rw, http.StatusInternalServerError, "querying stats", err)
			return
		}

		// Per-account summary for the accounts this user is mapped to.
		rows, err := pool.Query(ctx, `
SELECT
	ua.username,
	COUNT(uat.tweetid) tweet_count,
	(SELECT CAST(object->>'created_at' AS timestamp with time zone) FROM tweets WHERE id=MAX(uat.tweetid)) latest_tweet
FROM
	user_accounts ua
LEFT JOIN
	user_accounts_tweets uat ON uat.userid=ua.userid
WHERE 1=1
	AND ua.username = ANY($1::text[])
	AND uat.on_timeline
GROUP BY 1
ORDER BY 1
`, twitterAccounts)
		if err != nil {
			writeError(rw, http.StatusInternalServerError, "querying database", err)
			return
		}
		defer rows.Close()

		type twitterData struct {
			Username    string
			TweetCount  int
			LatestTweet time.Time
		}
		var tds []twitterData

		for rows.Next() {
			var td twitterData
			if err := rows.Scan(&td.Username, &td.TweetCount, &td.LatestTweet); err != nil {
				writeError(rw, http.StatusInternalServerError, "reading from database", err)
				return
			}
			tds = append(tds, td)
		}
		// Next returning false can mean a mid-stream failure rather than the
		// end of the result set, so check before rendering a partial list.
		if err := rows.Err(); err != nil {
			writeError(rw, http.StatusInternalServerError, "reading from database", err)
			return
		}
		rows.Close()

		if err := indexTmpl.Execute(rw, struct {
			Username         string
			TwitterAccounts  []twitterData
			TotalTweets      int
			TimelineTweets   int
			UnfetchedMedia   int
			UnfetchedRelated int
		}{
			Username:         user,
			TwitterAccounts:  tds,
			TotalTweets:      allTweets,
			TimelineTweets:   timelineTweets,
			UnfetchedMedia:   unfetchedMedia,
			UnfetchedRelated: unfetchedRelated,
		}); err != nil {
			log.Printf("index: executing template: %v", err)
		}
	})

	// isAllowedToSee reports whether the user behind ctx is mapped to
	// twitterUser in userMapping. Any failure to identify the user denies access.
	isAllowedToSee := func(ctx context.Context, twitterUser string) bool {
		user, err := userFromContext(ctx)
		if err != nil {
			return false
		}
		twitterAccounts := userMapping[user]
		for _, a := range twitterAccounts {
			if a == twitterUser {
				return true
			}
		}
		return false
	}
	// toInt parses s as a base-10 int, falling back to def when s is empty or
	// malformed.
	toInt := func(s string, def int) int {
		n, err := strconv.ParseInt(s, 10, 0)
		if err != nil {
			return def
		}
		return int(n)
	}
	// clamp limits n to the inclusive range [min, max].
	clamp := func(min, n, max int) int {
		if n < min {
			return min
		}
		if n > max {
			return max
		}
		return n
	}
	authR.HandleFunc("/view/{twitterUser}", func(rw http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		vars := mux.Vars(r)
		twitterUser := vars["twitterUser"]
		user, err := userFromContext(ctx)
		if err != nil {
			http.Error(rw, fmt.Errorf("userFromContext: %w", err).Error(), http.StatusInternalServerError)
			return
		}
		if !isAllowedToSee(ctx, twitterUser) {
			// Answer 404 rather than 403 so the response does not reveal
			// which accounts are being archived.
			writeError(rw, http.StatusNotFound, "no such twitter user being archived", fmt.Errorf("user %q attempted to access %q", user, twitterUser))
			return
		}

		q := r.URL.Query()
		pageSize := clamp(1, toInt(q.Get("page_size"), 20), 200)
		startFrom := toInt(q.Get("start_from"), 0)
		query := q.Get("q")

		// One row more than pageSize is requested; its ID becomes the
		// start_from cursor of the next page.
		rows, err := pool.Query(ctx, `
SELECT
	t.id,
	t.text,
	t.object,
	rt.object retweet_object,
	qt.object quote_object,
	CAST(COALESCE(t.object->'retweeted_status'->>'created_at', t.object->>'created_at') AS timestamp with time zone) created_at,
	CAST(qt.object->>'created_at' AS timestamp with time zone) quoted_created_at
FROM
	user_accounts_tweets uat
INNER JOIN
	user_accounts ua ON ua.userid=uat.userid
INNER JOIN
	tweets t ON t.id=uat.tweetid
LEFT JOIN
	tweets rt ON rt.id=CAST(t.object->'retweeted_status'->>'id' AS BIGINT)
LEFT JOIN
	tweets qt ON qt.id=CAST(t.object->>'quoted_status_id' AS BIGINT)
WHERE 1=1
	AND ua.username=$1
	AND uat.on_timeline
	AND ($3::bigint=0 OR t.id <= $3::bigint)
	AND ($4='' OR ($4<>'' AND (to_tsvector('english', t.text) @@ to_tsquery('english', $4) OR to_tsvector('english', t.object->'retweeted_status'->>'full_text') @@ to_tsquery('english', $4) OR t.object->'user'->>'screen_name'=$4)))
ORDER BY t.id DESC
LIMIT $2
`, twitterUser, pageSize+1, startFrom, query)
		if err != nil {
			writeError(rw, http.StatusInternalServerError, "querying database", err)
			return
		}
		defer rows.Close()

		type tweet struct {
			ID                      int64
			Text                    string
			CreatedAt               time.Time
			CreatedAtFriendly       string
			Object                  interface{}
			RetweetedObject         interface{}
			QuotedObject            interface{}
			QuotedCreatedAt         time.Time
			QuotedCreatedAtFriendly string
		}
		type twitterData struct {
			TwitterUsername string
			Query           string
			Tweets          []tweet
			NextTweetID     *int64

			FormatTweetText func(string, interface{}) safehtml.HTML
		}
		// pullIndices extracts the [start, end) pair from an entity's
		// "indices" array. It panics if the entity does not have that shape.
		pullIndices := func(m map[string]interface{}) [2]int {
			midx := m["indices"].([]interface{})
			midx0 := int(midx[0].(float64))
			midx1 := int(midx[1].(float64))
			return [2]int{midx0, midx1}
		}
		// escape HTML-escapes entity-supplied strings before they are spliced
		// into the hand-built markup below.
		escape := func(s string) string {
			return safehtml.HTMLEscaped(s).String()
		}
		td := twitterData{
			TwitterUsername: twitterUser,
			Query:           query,
			// FormatTweetText turns tweet text t into HTML using the entities
			// of the decoded tweet object tw: native media links are removed
			// and hashtags, URLs, mentions and symbols become anchors.
			FormatTweetText: func(t string, tw interface{}) safehtml.HTML {
				// Swap angle brackets for private-use placeholders so they
				// can be told apart from the markup inserted below. Each
				// replacement is one rune for one rune, which keeps the
				// entity indices valid.
				ltRep := string([]rune{0xe000})
				gtRep := string([]rune{0xe001})
				t = strings.ReplaceAll(t, "<", ltRep)
				t = strings.ReplaceAll(t, ">", gtRep)

				type span struct {
					span   [2]int
					whatDo string // remove, link

					linkTo   string // link only
					linkText string // link only
				}
				// Delete native media and add links.
				var spans []span
				obj := tw.(map[string]interface{})
				if ee, ok := obj["extended_entities"].(map[string]interface{}); ok {
					ems := ee["media"].([]interface{})
					for _, emi := range ems {
						em := emi.(map[string]interface{})
						spans = append(spans, span{
							span:   pullIndices(em),
							whatDo: "remove",
						})
					}
				}
				if es, ok := obj["entities"].(map[string]interface{}); ok {
					if hts, ok := es["hashtags"].([]interface{}); ok {
						for _, hti := range hts {
							ht := hti.(map[string]interface{})
							htt := ht["text"].(string)
							spans = append(spans, span{
								span:   pullIndices(ht),
								whatDo: "link",
								linkTo: fmt.Sprintf("https://twitter.com/hashtag/%s", htt),
							})
						}
					}
					if urls, ok := es["urls"].([]interface{}); ok {
						for _, urli := range urls {
							urlEnt := urli.(map[string]interface{})
							urldisp := urlEnt["display_url"].(string)
							urlexp := urlEnt["expanded_url"].(string)
							spans = append(spans, span{
								span:     pullIndices(urlEnt),
								whatDo:   "link",
								linkTo:   urlexp,
								linkText: urldisp,
							})
						}
					}
					if mentions, ok := es["user_mentions"].([]interface{}); ok {
						for _, mentioni := range mentions {
							mention := mentioni.(map[string]interface{})
							mentionuser := mention["screen_name"].(string)
							spans = append(spans, span{
								span:   pullIndices(mention),
								whatDo: "link",
								linkTo: fmt.Sprintf("https://twitter.com/%s", mentionuser),
							})
						}
					}
					if symbols, ok := es["symbols"].([]interface{}); ok {
						for _, symboli := range symbols {
							symbol := symboli.(map[string]interface{})
							symbolname := symbol["text"].(string)
							spans = append(spans, span{
								span:   pullIndices(symbol),
								whatDo: "link",
								linkTo: fmt.Sprintf("?q=$%s", symbolname),
							})
						}
					}
				}
				// Sort spans from the end of the text to the beginning, so
				// that applying one never shifts the indices of those still
				// to be applied.
				sort.Slice(spans, func(a, b int) bool {
					return spans[a].span[0] > spans[b].span[0]
				})
				// Merge overlapping spans into the earlier-starting one.
				// Several media items attached to one tweet typically share
				// identical indices, which is the common case here.
				merged := make([]span, 0, len(spans))
				for i := 0; i < len(spans)-1; i++ {
					cur := spans[i]
					// prev starts at or before cur because of the sort order.
					// It is a pointer so that extending it is visible when
					// it is itself processed on the next iteration.
					prev := &spans[i+1]
					contained := prev.span[1] >= cur.span[1]
					overlapping := prev.span[1] > cur.span[0]
					if contained || overlapping {
						if cur.whatDo != "remove" || prev.whatDo != "remove" {
							log.Printf("found overlapping non-remove spans!")
						}
						if cur.span[1] > prev.span[1] {
							prev.span[1] = cur.span[1]
						}
						continue
					}
					merged = append(merged, cur)
				}
				if len(spans) > 0 {
					merged = append(merged, spans[len(spans)-1])
				}
				spans = merged

				runed := []rune(t)
				for _, sp := range spans {
					start, end := sp.span[0], sp.span[1]
					// Indices that do not fit the supplied text would make the
					// slicing below panic; skip them and leave the text as is.
					if start < 0 || end < start || end > len(runed) {
						log.Printf("skipping %v span with out-of-range indices [%d, %d) for text of %d runes", sp.whatDo, start, end, len(runed))
						continue
					}
					switch sp.whatDo {
					case "remove":
						// Delete text from start to end.
						runed = append(runed[:start], runed[end:]...)
					case "link":
						// Add a link. Without an explicit link text, the
						// covered tweet text becomes the anchor text.
						var text []rune
						if sp.linkText == "" {
							text = runed[start:end]
						} else {
							text = []rune(escape(sp.linkText))
						}
						runedBits := [][]rune{
							runed[:start],
							[]rune("<a href=\""),
							[]rune(escape(sp.linkTo)),
							[]rune("\">"),
							text,
							[]rune("</a>"),
							runed[end:],
						}
						finalLen := 0
						for _, s := range runedBits {
							finalLen += len(s)
						}
						// Build into a fresh slice; the pieces above still
						// alias the old one.
						next := make([]rune, finalLen)
						p := 0
						for _, s := range runedBits {
							p += copy(next[p:], s)
						}
						runed = next
					default:
						log.Printf("unknown span operation %v", sp.whatDo)
					}
				}
				t = string(runed)

				// HTML escape any <>
				t = strings.ReplaceAll(t, ltRep, "&lt;")
				t = strings.ReplaceAll(t, gtRep, "&gt;")

				return uncheckedconversions.HTMLFromStringKnownToSatisfyTypeContract(t)
			},
		}
		now := time.Now()
		for rows.Next() {
			var t tweet
			// The retweeted and quoted objects come from LEFT JOINs and may
			// be NULL, hence the pgtype wrappers.
			var o, ro, qo pgtype.JSONB
			var qca pgtype.Timestamptz
			if err := rows.Scan(&t.ID, &t.Text, &o, &ro, &qo, &t.CreatedAt, &qca); err != nil {
				writeError(rw, http.StatusInternalServerError, "reading from database", err)
				return
			}
			if o.Status == pgtype.Present {
				var oo interface{}
				if err := json.Unmarshal(o.Bytes, &oo); err != nil {
					writeError(rw, http.StatusInternalServerError, "parsing JSON from database (object)", err)
					return
				}
				t.Object = oo
				t.CreatedAtFriendly = agoFriendly(now, t.CreatedAt)
			}
			if ro.Status == pgtype.Present {
				var oo interface{}
				if err := json.Unmarshal(ro.Bytes, &oo); err != nil {
					writeError(rw, http.StatusInternalServerError, "parsing JSON from database (retweeted object)", err)
					return
				}
				t.RetweetedObject = oo
			}
			if qo.Status == pgtype.Present {
				var oo interface{}
				if err := json.Unmarshal(qo.Bytes, &oo); err != nil {
					writeError(rw, http.StatusInternalServerError, "parsing JSON from database (quoted object)", err)
					return
				}
				t.QuotedObject = oo
			}
			if qca.Status == pgtype.Present {
				t.QuotedCreatedAt = qca.Time
				t.QuotedCreatedAtFriendly = agoFriendly(now, qca.Time)
			}
			td.Tweets = append(td.Tweets, t)
		}
		// As on the index page: tell a failed read apart from the end of the
		// result set before rendering.
		if err := rows.Err(); err != nil {
			writeError(rw, http.StatusInternalServerError, "reading from database", err)
			return
		}
		rows.Close()

		// The extra row, if present, only supplies the cursor for the next page.
		if len(td.Tweets) > pageSize {
			td.NextTweetID = &td.Tweets[pageSize].ID
			td.Tweets = td.Tweets[:pageSize]
		}

		if err := tweetsTmpl.Execute(rw, td); err != nil {
			log.Printf("tweets: executing template: %v", err)
		}
	})

	log.Printf("now listening on %s", *listen)
	log.Print(http.ListenAndServe(*listen, r))
}
func agoFriendly(now, at time.Time) string {
	ago := now.Sub(at)
	switch {
	case at.Year() != now.Year():
		return at.Format("Jan 2, 2006")
	case at.YearDay() != now.YearDay():
		return at.Format("Jan 2")
	case ago.Hours() >= 1.0:
		return fmt.Sprintf("%dh", int(ago.Hours()))
	case ago.Minutes() >= 1.0:
		return fmt.Sprintf("%dm", int(ago.Minutes()))
	case ago.Seconds() >= 0.0:
		return fmt.Sprintf("%ds", int(ago.Seconds()))
	default:
		return fmt.Sprintf("in %ds", -int(ago.Seconds()))
	}
}