package main

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"log"
	mathrand "math/rand"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"golang.org/x/oauth2"
)

var (
	// debugOnlyAssumeEmail, when non-empty, makes every request act as this
	// user instead of reading the x-pomerium-claim-email header. Debug only.
	debugOnlyAssumeEmail = flag.String("debug_only_assume_email", "", "If set, ignores the x-pomerium-claim-email header and just uses this static address.")

	// addr is split on commas by main, which opens one listener per entry.
	addr = flag.String("addr", "localhost:10908", "Comma-separated list of addresses (host:port) to listen on.")

	// dataDir holds one JSON file per user. It defaults to $STATE_DIRECTORY
	// (presumably provided by a service manager such as systemd; confirm).
	dataDir = flag.String("data_dir", os.Getenv("STATE_DIRECTORY"), "Data directory to store data in.")

	// OAuth client credentials; they default to the OAUTH_CLIENT_ID and
	// OAUTH_CLIENT_SECRET environment variables.
	oauthClientID     = flag.String("oauth_client_id", os.Getenv("OAUTH_CLIENT_ID"), "Sets the OAuth client ID to use.")
	oauthClientSecret = flag.String("oauth_client_secret", os.Getenv("OAUTH_CLIENT_SECRET"), "Sets the OAuth client secret to use.")

	// baseURL is the externally reachable URL of this server; main appends
	// "/oauth" to it to build the OAuth redirect URL.
	baseURL = flag.String("base_url", "", "Sets my publicly-visible base URL (scheme and host, e.g. https://example.com).")

	// httpMux routes the application's pages; main registers the handlers.
	httpMux = http.NewServeMux()
)

// Post is one liked post. It is both decoded from the Tumblr likes API
// (as an element of likesResponse's liked_posts) and persisted in the
// user's JSON file; only the post's URL is kept.
type Post struct {
	PostURL string `json:"post_url"`
}

// User is the per-user state, stored as JSON in the file named by
// pathForUser. Only the exported fields are serialized.
type User struct {
	// Not serialized: originPath is the JSON file this user is loaded from
	// and saved to; email is the address it was looked up by.
	originPath string
	email      string

	// OAuthToken is the token stored by handleOAuth (nil until the user has
	// completed login). Likes is the list last fetched by refreshLikes.
	OAuthToken *oauth2.Token `json:"oauth_token"`
	Likes      []Post        `json:"likes"`
}

// save persists u as JSON to u.originPath.
//
// The JSON is written to a temporary file in the same directory, synced,
// and then renamed over the destination. A crash or a failed write
// therefore never leaves a truncated file behind, which loadUser would
// otherwise be unable to unmarshal on every later request. The file ends up
// with permission 0600 (os.CreateTemp's mode), as the previous
// os.WriteFile(p, bs, 0600) produced for new files.
func (u *User) save() error {
	p := u.originPath
	if p == "" {
		return fmt.Errorf("originPath somehow unset")
	}

	bs, err := json.Marshal(u)
	if err != nil {
		return err
	}

	// Same directory as p so the final rename stays on one filesystem.
	f, err := os.CreateTemp(filepath.Dir(p), filepath.Base(p)+".tmp-*")
	if err != nil {
		return fmt.Errorf("creating temp file for %v: %w", p, err)
	}
	tmp := f.Name()

	if _, err := f.Write(bs); err != nil {
		f.Close()
		os.Remove(tmp)
		return fmt.Errorf("writing %v: %w", tmp, err)
	}
	if err := f.Sync(); err != nil {
		f.Close()
		os.Remove(tmp)
		return fmt.Errorf("syncing %v: %w", tmp, err)
	}
	if err := f.Close(); err != nil {
		os.Remove(tmp)
		return fmt.Errorf("closing %v: %w", tmp, err)
	}
	if err := os.Rename(tmp, p); err != nil {
		os.Remove(tmp)
		return fmt.Errorf("renaming %v to %v: %w", tmp, p, err)
	}
	return nil
}

// pathForUser maps an email address to its JSON file inside dataDir. The
// file name is the lowercase hex SHA-256 of the address plus ".json", so
// the raw address never appears on disk and needs no escaping.
func pathForUser(email, dataDir string) string {
	digest := sha256.Sum256([]byte(email))
	fileName := hex.EncodeToString(digest[:]) + ".json"
	return filepath.Join(dataDir, fileName)
}

// loadUser returns the stored User for email from dataDir.
//
// If no file exists yet, it returns a fresh, empty User, which is not
// written to disk until save is called. Any other read failure (e.g.
// permission denied) is returned as-is. It is no longer misreported as a
// JSON error from unmarshalling empty data. The returned User remembers its
// path and email for a later save.
func loadUser(email, dataDir string) (*User, error) {
	p := pathForUser(email, dataDir)
	u := &User{originPath: p, email: email}

	d, err := os.ReadFile(p)
	switch {
	case errors.Is(err, fs.ErrNotExist):
		log.Printf("generating new user for %v (path= %v)", email, p)
		return u, nil
	case err != nil:
		return nil, fmt.Errorf("reading user file %v: %w", p, err)
	}

	if err := json.Unmarshal(d, u); err != nil {
		return nil, fmt.Errorf("unmarshalling JSON from %v: %w", p, err)
	}
	log.Printf("loaded user for %v from %v", email, p)
	return u, nil
}

// pomeriumEmailHeader is the request header that carries the authenticated
// user's email address (see loadUserMiddleware).
const pomeriumEmailHeader = "x-pomerium-claim-email"

// oauthStateCookie names the cookie that holds the OAuth state value between
// performLogin and handleOAuth.
const oauthStateCookie = "tumblr-oauth-state"

// contextKeyType is an unexported key type, so context values stored under
// it cannot collide with keys defined by other packages.
type contextKeyType struct{}

// userContextKey is the key under which loadUserMiddleware attaches the
// request's *User.
var userContextKey = contextKeyType{}

// user returns the *User attached to ctx and whether one was present.
func user(ctx context.Context) (*User, bool) {
	if found, ok := ctx.Value(userContextKey).(*User); ok {
		return found, true
	}
	return nil, false
}

// app bundles the configuration the HTTP handlers need; main builds the
// single instance. oauthConfig is the Tumblr OAuth2 client configuration,
// dataDir is where per-user JSON files live (see pathForUser), and
// debugOnlyAssumeEmail, when non-empty, replaces the Pomerium email header
// for every request (see loadUserMiddleware).
type app struct {
	oauthConfig          *oauth2.Config
	dataDir              string
	debugOnlyAssumeEmail string
}

// loadUserMiddleware works out the caller's email address and loads (or
// creates) the matching *User from a.dataDir. It then passes the request to
// next with that user attached to the context; retrieve it with user.
//
// The email comes from the x-pomerium-claim-email header, presumably set by
// a Pomerium proxy in front of this server. A non-empty
// a.debugOnlyAssumeEmail overrides it for every request. A request with no
// email gets 403; a load failure gets 500.
func (a *app) loadUserMiddleware(next http.Handler) http.Handler {
	serve := func(rw http.ResponseWriter, r *http.Request) {
		email := r.Header.Get(pomeriumEmailHeader)
		if override := a.debugOnlyAssumeEmail; override != "" {
			email = override
		}
		if email == "" {
			http.Error(rw, "no email", http.StatusForbidden)
			return
		}

		loaded, err := loadUser(email, a.dataDir)
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
			return
		}

		withUser := context.WithValue(r.Context(), userContextKey, loaded)
		next.ServeHTTP(rw, r.WithContext(withUser))
	}
	return http.HandlerFunc(serve)
}

// performLogin starts the OAuth2 authorization-code flow. It stores a fresh
// 32-byte random state value (hex-encoded) in the tumblr-oauth-state cookie
// and sends a 303 redirect to the provider's authorization URL carrying the
// same state. handleOAuth later compares the returned state with the cookie.
//
// ctx and u are currently unused.
func (a *app) performLogin(ctx context.Context, u *User, rw http.ResponseWriter, r *http.Request) {
	var nonce [32]byte
	n, err := rand.Read(nonce[:])
	switch {
	case err != nil:
		http.Error(rw, "bad random: "+err.Error(), http.StatusInternalServerError)
		return
	case n != len(nonce):
		http.Error(rw, fmt.Sprintf("bad random: got %d of random not %d", n, len(nonce)), http.StatusInternalServerError)
		return
	}

	state := hex.EncodeToString(nonce[:])
	stateCookie := &http.Cookie{
		Name:     oauthStateCookie,
		Value:    state,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(rw, stateCookie)

	authURL := a.oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOnline)
	http.Redirect(rw, r, authURL, http.StatusSeeOther)
}

// handleOAuth is the OAuth2 redirect target (/oauth). It:
//   - checks that the "state" query parameter matches the tumblr-oauth-state
//     cookie set by performLogin;
//   - exchanges the authorization code for a token and stores it on u;
//   - persists u and redirects to /refresh to load the user's likes.
//
// The state cookie is expired as soon as it has been checked, so each state
// value is single-use. An empty state is rejected outright. If the provider
// redirects back with an "error" parameter instead of a code (e.g. the user
// denied access), that error is reported directly rather than attempting a
// code exchange that cannot succeed.
func (a *app) handleOAuth(ctx context.Context, u *User, rw http.ResponseWriter, r *http.Request) {
	cookie, err := r.Cookie(oauthStateCookie)
	if err != nil {
		http.Error(rw, fmt.Sprintf("getting state cookie: %v", err), http.StatusBadRequest)
		return
	}
	query := r.URL.Query()
	if cookie.Value == "" || cookie.Value != query.Get("state") {
		http.Error(rw, "invalid state", http.StatusBadRequest)
		return
	}

	// Expire the state cookie. performLogin sets it without an explicit
	// Path while serving /login, so its default path is "/"; deleting with
	// Path "/" targets the same cookie.
	http.SetCookie(rw, &http.Cookie{
		Name:     oauthStateCookie,
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})

	if errCode := query.Get("error"); errCode != "" {
		http.Error(rw, fmt.Sprintf("authorization failed: %s: %s", errCode, query.Get("error_description")), http.StatusBadRequest)
		return
	}
	code := query.Get("code")
	if code == "" {
		http.Error(rw, "missing authorization code", http.StatusBadRequest)
		return
	}

	tok, err := a.oauthConfig.Exchange(ctx, code)
	if err != nil {
		http.Error(rw, fmt.Sprintf("oauth2 exchange: %v", err), http.StatusInternalServerError)
		return
	}

	u.OAuthToken = tok
	if err := u.save(); err != nil {
		http.Error(rw, fmt.Sprintf("persisting user: %v", err), http.StatusInternalServerError)
		return
	}

	http.Redirect(rw, r, "/refresh", http.StatusSeeOther)
}

// redirectToLogin sends a 303 redirect to /login. requireOAuthMiddleware
// handles /login by calling performLogin, which starts the OAuth flow.
func (a *app) redirectToLogin(rw http.ResponseWriter, r *http.Request) {
	const loginPath = "/login"
	http.Redirect(rw, r, loginPath, http.StatusSeeOther)
}

// requireOAuthMiddleware gates access to next. It expects loadUserMiddleware
// to have run first and returns 500 "user missing" otherwise. It:
//   - handles /login via performLogin and /oauth via handleOAuth itself;
//   - passes the request to next when the user already has stored likes or
//     holds a currently valid OAuth token;
//   - otherwise redirects to /login.
func (a *app) requireOAuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		u, ok := user(ctx)
		if !ok {
			http.Error(rw, "user missing", http.StatusInternalServerError)
			return
		}

		switch r.URL.Path {
		case "/login":
			a.performLogin(ctx, u, rw, r)
		case "/oauth":
			a.handleOAuth(ctx, u, rw, r)
		default:
			// Stored likes are enough to browse; the token is only checked
			// when there are none.
			allowed := u != nil && (len(u.Likes) > 0 || (u.OAuthToken != nil && u.OAuthToken.Valid()))
			if !allowed {
				a.redirectToLogin(rw, r)
				return
			}
			next.ServeHTTP(rw, r)
		}
	})
}

// likesResponse mirrors the subset of the Tumblr /v2/user/likes JSON
// response that refreshLikes reads. That subset is the page of liked posts
// and the "next" pagination link. refreshLikes treats an empty
// Response.Links.Next.Href as the last page; otherwise it appends the href
// to https://api.tumblr.com to fetch the next page.
type likesResponse struct {
	Response struct {
		LikedPosts []Post `json:"liked_posts"`
		Links      struct {
			Next struct {
				Href string `json:"href"`
			} `json:"next"`
		} `json:"_links"`
	} `json:"response"`
}

// refreshLikes re-fetches the current user's full list of liked posts from
// the Tumblr API by following the "next" pagination links. It stores the
// list on the user, persists it with save, and redirects to
// "/?refreshed=true".
//
// If there is no user in the request context, or the stored token is not
// valid, the client is redirected to /login instead. An API or decoding
// failure is reported as a 500 and nothing is saved.
func (a *app) refreshLikes(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	u, ok := user(ctx)
	if !ok || u == nil || !u.OAuthToken.Valid() {
		a.redirectToLogin(rw, r)
		return
	}

	client := a.oauthConfig.Client(ctx, u.OAuthToken)

	const urlPrefix = "https://api.tumblr.com"
	var likes []Post
	// Guard against the API handing back a "next" link we have already
	// followed, which would otherwise loop forever.
	seen := make(map[string]bool)
	for urlSuffix := "/v2/user/likes"; urlSuffix != ""; {
		if seen[urlSuffix] {
			log.Printf("likes pagination revisited %v; stopping", urlSuffix)
			break
		}
		seen[urlSuffix] = true

		posts, next, err := fetchLikesPage(ctx, client, urlPrefix+urlSuffix)
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
			return
		}
		likes = append(likes, posts...)
		urlSuffix = next
	}

	u.Likes = likes
	if err := u.save(); err != nil {
		http.Error(rw, fmt.Sprintf("saving likes: %v", err), http.StatusInternalServerError)
		return
	}

	http.Redirect(rw, r, "/?refreshed=true", http.StatusSeeOther)
}

// fetchLikesPage GETs one page of likes at pageURL using client. It returns
// the page's posts and its "next" href, which is "" when absent. The
// response body is fully consumed and closed before returning, so a
// paginating caller never holds more than one connection open.
func fetchLikesPage(ctx context.Context, client *http.Client, pageURL string) ([]Post, string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, pageURL, nil)
	if err != nil {
		return nil, "", fmt.Errorf("formulating likes request: %w", err)
	}
	log.Printf("fetching likes page %v", req.URL)
	resp, err := client.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("performing likes request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Bounded read: the body is only used for the error message.
		errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
		return nil, "", fmt.Errorf("likes request status %v\n\n%v", resp.StatusCode, string(errBody))
	}

	var page likesResponse
	if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
		return nil, "", fmt.Errorf("decoding likes body as json: %w", err)
	}
	// Drain any trailing bytes so the transport can reuse the connection.
	_, _ = io.Copy(io.Discard, resp.Body)

	return page.Response.LikedPosts, page.Response.Links.Next.Href, nil
}

// index renders the home page: how many likes are stored for the current
// user, a link to /random, and a link to /refresh. Without a user in the
// request context it redirects to /login.
func (a *app) index(rw http.ResponseWriter, r *http.Request) {
	u, ok := user(r.Context())
	if !ok {
		a.redirectToLogin(rw, r)
		return
	}

	const page = `
<!DOCTYPE html>
<h1>%d likes</h1>
<h2><a href="/random">random</a></h2>
<br>
<br>
<br>
<br>
<a href="/refresh">refresh</a>
`
	fmt.Fprintf(rw, page, len(u.Likes))
}

// random sends a 303 redirect to the URL of a randomly chosen stored like,
// picked with math/rand's Intn. It redirects to / when there are no stored
// likes, and to /login when there is no user in the request context.
func (a *app) random(rw http.ResponseWriter, r *http.Request) {
	u, ok := user(r.Context())
	switch {
	case !ok:
		a.redirectToLogin(rw, r)
	case len(u.Likes) == 0:
		http.Redirect(rw, r, "/", http.StatusSeeOther)
	default:
		pick := u.Likes[mathrand.Intn(len(u.Likes))]
		http.Redirect(rw, r, pick.PostURL, http.StatusSeeOther)
	}
}

// main parses flags, builds the app, and serves HTTP on every address listed
// in -addr. It shuts the server down gracefully on SIGINT or SIGTERM, or
// when any listener fails. After the first signal, signal handling is
// released, so a second signal terminates the process immediately.
func main() {
	flag.Parse()
	log.Println("using oauth client ID", *oauthClientID)

	// -addr is a comma-separated list; tolerate spaces and empty entries.
	var listenAddrs []string
	for _, entry := range strings.Split(*addr, ",") {
		if entry = strings.TrimSpace(entry); entry != "" {
			listenAddrs = append(listenAddrs, entry)
		}
	}
	if len(listenAddrs) == 0 {
		log.Fatalf("no listen addresses in -addr=%q", *addr)
	}

	// Trim a trailing slash so -base_url=https://host/ does not yield a
	// "//oauth" redirect URI.
	redirectURL := strings.TrimSuffix(*baseURL, "/") + "/oauth"

	a := &app{
		oauthConfig: &oauth2.Config{
			ClientID:     *oauthClientID,
			ClientSecret: *oauthClientSecret,
			Endpoint: oauth2.Endpoint{
				AuthURL:  "https://www.tumblr.com/oauth2/authorize",
				TokenURL: "https://api.tumblr.com/v2/oauth2/token",
			},
			Scopes:      []string{"basic"},
			RedirectURL: redirectURL,
		},
		dataDir:              *dataDir,
		debugOnlyAssumeEmail: *debugOnlyAssumeEmail,
	}
	httpMux.HandleFunc("/refresh", a.refreshLikes)
	httpMux.HandleFunc("/", a.index)
	httpMux.HandleFunc("/random", a.random)

	// ReadHeaderTimeout bounds how long a client may take to send request
	// headers, so slow or idle connections cannot be held open indefinitely.
	s := &http.Server{
		Handler:           a.loadUserMiddleware(a.requireOAuthMiddleware(httpMux)),
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Register for signals before any listener starts, so an early signal
	// still takes the graceful path. SIGTERM is what service managers send
	// on stop. A listener failure calls cancel on the parent context, which
	// also ends ctx.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
	defer stop()

	for _, singleAddr := range listenAddrs {
		singleAddr := singleAddr
		go func() {
			l, err := net.Listen("tcp", singleAddr)
			if err != nil {
				log.Printf("listen on %v: %v", singleAddr, err)
				cancel()
				return
			}
			log.Printf("starting to serve on %v", singleAddr)
			if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) {
				log.Printf("HTTP server Serve on %v: %v", singleAddr, err)
				cancel()
			}
		}()
	}

	<-ctx.Done()
	log.Println("shutting down gracefully (SIGINT again to stop immediately)")
	// Unregister so a second signal gets the default (terminating) behaviour.
	stop()
	if err := s.Shutdown(context.Background()); err != nil {
		log.Printf("HTTP server Shutdown: %v", err)
	}
}