depot/go/tumblrandom/tumblrandom.go

358 lines
9.1 KiB
Go

package main
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"io/fs"
"log"
"math/big"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"strings"
"golang.org/x/oauth2"
)
// Command-line flags. Several default to environment variables so the binary
// can be configured without putting values on the command line.
var (
	debugOnlyAssumeEmail = flag.String("debug_only_assume_email", "", "If set, ignores the x-pomerium-claim-email header and just uses this static address.")
	addr                 = flag.String("addr", "localhost:10908", "Comma-separated list of host:port addresses to listen on.")
	dataDir              = flag.String("data_dir", os.Getenv("STATE_DIRECTORY"), "Data directory to store data in.")
	oauthClientID        = flag.String("oauth_client_id", os.Getenv("OAUTH_CLIENT_ID"), "Sets the OAuth client ID to use.")
	// The secret goes through secretStringFlag so its environment-supplied
	// value is never echoed by -help or by the usage dump on a bad flag.
	oauthClientSecret = secretStringFlag("oauth_client_secret", "OAUTH_CLIENT_SECRET", "Sets the OAuth client secret to use.")
	baseURL           = flag.String("base_url", "", "Sets my publicly-visible base URL (scheme and host); \"/oauth\" is appended to form the OAuth redirect URL.")

	// httpMux routes the application's handlers; main registers them on it.
	httpMux = http.NewServeMux()
)

// secretStringFlag registers a string flag named name on the default flag set.
// Its value defaults to the environment variable env, but the default shown in
// usage output is empty, so the secret is never printed.
//
// flag.StringVar records DefValue (what PrintDefaults shows) at registration
// time. Registering with "" and only then loading the environment value keeps
// DefValue empty while still providing the default. flag.Parse overwrites the
// value when the flag is given explicitly.
func secretStringFlag(name, env, usage string) *string {
	p := new(string)
	flag.StringVar(p, name, "", usage+" Defaults to $"+env+" (value hidden).")
	*p = os.Getenv(env)
	return p
}
type (
	// Post is a single liked Tumblr post, both as decoded from the likes API
	// and as persisted in a user's data file.
	Post struct {
		// PostURL is the post's URL; the /random handler redirects to it.
		PostURL string `json:"post_url"`
	}

	// User is the per-user state stored as JSON by save and read by loadUser.
	// Only the exported fields are serialized by encoding/json.
	User struct {
		// originPath is the data file this user was loaded from and is saved to.
		originPath string
		// email is the address the user was identified by.
		email string

		// OAuthToken is the Tumblr OAuth2 token; nil for a freshly created user.
		OAuthToken *oauth2.Token `json:"oauth_token"`
		// Likes caches the user's liked posts from the most recent refresh.
		Likes []Post `json:"likes"`
	}
)
// save persists u as JSON to u.originPath with 0600 permissions.
//
// The JSON is written to a temporary file in the same directory, flushed to
// disk, and renamed over the destination. A crash, or two overlapping saves
// for the same user, therefore leaves either the old or a complete new file,
// never a truncated or interleaved one.
func (u *User) save() error {
	p := u.originPath
	if p == "" {
		return fmt.Errorf("originPath somehow unset")
	}
	bs, err := json.Marshal(u)
	if err != nil {
		return fmt.Errorf("marshalling user: %w", err)
	}
	return writeFileAtomic(p, bs)
}

// writeFileAtomic replaces path with data by writing a sibling temp file
// (created with mode 0600 by os.CreateTemp) and renaming it into place.
func writeFileAtomic(path string, data []byte) (err error) {
	tmp, err := os.CreateTemp(filepath.Dir(path), "."+filepath.Base(path)+".tmp-*")
	if err != nil {
		return fmt.Errorf("creating temp file for %v: %w", path, err)
	}
	defer func() {
		if err != nil {
			// Best-effort cleanup. Close may report "already closed", and the
			// original error is what matters to the caller.
			_ = tmp.Close()
			_ = os.Remove(tmp.Name())
		}
	}()
	if _, err = tmp.Write(data); err != nil {
		return fmt.Errorf("writing %v: %w", tmp.Name(), err)
	}
	if err = tmp.Sync(); err != nil {
		return fmt.Errorf("syncing %v: %w", tmp.Name(), err)
	}
	if err = tmp.Close(); err != nil {
		return fmt.Errorf("closing %v: %w", tmp.Name(), err)
	}
	if err = os.Rename(tmp.Name(), path); err != nil {
		return fmt.Errorf("renaming %v to %v: %w", tmp.Name(), path, err)
	}
	return nil
}
// pathForUser returns the JSON data-file path for email inside dataDir.
// The file name is the hex SHA-256 of the email plus ".json", so arbitrary
// header values cannot influence the path beyond choosing the hash.
func pathForUser(email, dataDir string) string {
	hasher := sha256.New()
	// Writes to a hash.Hash never return an error.
	_, _ = io.WriteString(hasher, email)
	fileName := hex.EncodeToString(hasher.Sum(nil)) + ".json"
	return filepath.Join(dataDir, fileName)
}
// loadUser returns the persisted state for email, read from the JSON file at
// pathForUser(email, dataDir).
//
// If that file does not exist yet, it returns a fresh, empty User bound to that
// path, so a later save creates the file. Any other read error, or malformed
// JSON, is returned to the caller with the path for context.
func loadUser(email, dataDir string) (*User, error) {
	p := pathForUser(email, dataDir)
	u := &User{originPath: p, email: email}
	d, err := os.ReadFile(p)
	switch {
	case errors.Is(err, fs.ErrNotExist):
		log.Printf("generating new user for %v (path= %v)", email, p)
		return u, nil
	case err != nil:
		// Without this branch, errors such as EACCES would fall through to
		// json.Unmarshal(nil) and be misreported as
		// "unexpected end of JSON input".
		return nil, fmt.Errorf("reading user data from %v: %w", p, err)
	}
	// Unmarshal fills only the exported fields; originPath and email survive.
	if err := json.Unmarshal(d, u); err != nil {
		return nil, fmt.Errorf("unmarshalling JSON from %v: %w", p, err)
	}
	log.Printf("loaded user for %v from %v", email, p)
	return u, nil
}
// pomeriumEmailHeader is the request header that loadUserMiddleware reads the
// caller's email address from.
const pomeriumEmailHeader = "x-pomerium-claim-email"

// oauthStateCookie names the cookie that carries the OAuth2 state value from
// /login to the /oauth callback.
const oauthStateCookie = "tumblr-oauth-state"
// contextKeyType is a private key type so the user stored in a request context
// cannot collide with values set by other packages.
type contextKeyType struct{}

// userContextKey is the key under which loadUserMiddleware stores the *User.
var userContextKey = contextKeyType{}

// user returns the *User attached to ctx and whether one was present.
// When none is present it returns (nil, false).
func user(ctx context.Context) (*User, bool) {
	if found, isUser := ctx.Value(userContextKey).(*User); isUser {
		return found, true
	}
	return nil, false
}
// app holds the configuration shared by all HTTP handlers and middleware.
type app struct {
	// oauthConfig is the Tumblr OAuth2 client configuration built in main.
	oauthConfig *oauth2.Config
	// dataDir is the directory holding the per-user JSON files.
	dataDir string
	// debugOnlyAssumeEmail, when non-empty, replaces the Pomerium email header
	// as the caller's identity.
	debugOnlyAssumeEmail string
}
// loadUserMiddleware identifies the caller by email, loads (or creates) that
// caller's *User via loadUser, and passes it to next through the request
// context.
//
// The email is a.debugOnlyAssumeEmail when set, otherwise the value of the
// x-pomerium-claim-email header. Requests with neither get 403, and load
// failures get 500. The header is trusted as-is.
// NOTE(review): this presumably relies on the service only being reachable
// through Pomerium; confirm deployment.
func (a *app) loadUserMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		email := r.Header.Get(pomeriumEmailHeader)
		if a.debugOnlyAssumeEmail != "" {
			email = a.debugOnlyAssumeEmail
		}
		if email == "" {
			http.Error(rw, "no email", http.StatusForbidden)
			return
		}

		loaded, err := loadUser(email, a.dataDir)
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
			return
		}

		withUser := context.WithValue(r.Context(), userContextKey, loaded)
		next.ServeHTTP(rw, r.WithContext(withUser))
	})
}
// performLogin starts the OAuth2 authorization-code flow. It generates a
// random state value, stores it in the oauthStateCookie cookie, and redirects
// (303) to Tumblr's authorization URL carrying that state.
//
// ctx and u are not used here; they are presumably kept so the signature
// matches handleOAuth.
func (a *app) performLogin(ctx context.Context, u *User, rw http.ResponseWriter, r *http.Request) {
	// 32 random bytes, hex-encoded to 64 characters of state.
	var nonce [32]byte
	n, err := rand.Read(nonce[:])
	switch {
	case err != nil:
		http.Error(rw, "bad random: "+err.Error(), http.StatusInternalServerError)
		return
	case n != len(nonce):
		http.Error(rw, fmt.Sprintf("bad random: got %d of random not %d", n, len(nonce)), http.StatusInternalServerError)
		return
	}
	state := hex.EncodeToString(nonce[:])

	stateCookie := &http.Cookie{
		Name:     oauthStateCookie,
		Value:    state,
		HttpOnly: true,
		// Lax still sends the cookie on the top-level redirect back to /oauth.
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(rw, stateCookie)

	authURL := a.oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOnline)
	http.Redirect(rw, r, authURL, http.StatusSeeOther)
}
// handleOAuth serves the OAuth2 redirect endpoint (/oauth). It checks that the
// `state` query parameter matches the cookie set by performLogin, exchanges the
// authorization `code` for a token, stores the token on u, persists u, and
// redirects (303) to /refresh.
//
// The state cookie is single-use: it is expired as soon as it has been checked,
// so this browser will not accept the same state value a second time.
// Failures reported by the provider through the `error` query parameter (e.g.
// the user declined) are returned as 400 instead of attempting an exchange
// with an empty code.
func (a *app) handleOAuth(ctx context.Context, u *User, rw http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	cookie, err := r.Cookie(oauthStateCookie)
	if err != nil {
		http.Error(rw, fmt.Sprintf("getting state cookie: %v", err), http.StatusBadRequest)
		return
	}
	// An empty cookie must never "match" a missing state parameter.
	if cookie.Value == "" || cookie.Value != query.Get("state") {
		http.Error(rw, "invalid state", http.StatusBadRequest)
		return
	}
	// Consume the state. performLogin set the cookie without a Path, so its
	// default path is "/" (the directory of /login), and we expire it there.
	http.SetCookie(rw, &http.Cookie{
		Name:     oauthStateCookie,
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
	if oauthErr := query.Get("error"); oauthErr != "" {
		http.Error(rw, fmt.Sprintf("oauth2 authorization failed: %v %v", oauthErr, query.Get("error_description")), http.StatusBadRequest)
		return
	}
	tok, err := a.oauthConfig.Exchange(ctx, query.Get("code"))
	if err != nil {
		http.Error(rw, fmt.Sprintf("oauth2 exchange: %v", err), http.StatusInternalServerError)
		return
	}
	u.OAuthToken = tok
	if err := u.save(); err != nil {
		http.Error(rw, fmt.Sprintf("persisting user: %v", err), http.StatusInternalServerError)
		return
	}
	http.Redirect(rw, r, "/refresh", http.StatusSeeOther)
}
// redirectToLogin sends the client to /login with 303 See Other. That path is
// intercepted by requireOAuthMiddleware, which starts the OAuth flow.
func (a *app) redirectToLogin(rw http.ResponseWriter, r *http.Request) {
	const loginPath = "/login"
	http.Redirect(rw, r, loginPath, http.StatusSeeOther)
}
// requireOAuthMiddleware handles the OAuth endpoints itself and gates every
// other request on the user having usable state.
//
//   - /login and /oauth are served here (performLogin / handleOAuth), before
//     the mux.
//   - A user with no cached likes and no valid token is sent to /login.
//   - Everything else reaches next. A user with cached likes but an expired
//     token can still browse; /refresh checks the token itself.
//
// It expects loadUserMiddleware to have put a *User in the context and
// responds 500 if none is present.
func (a *app) requireOAuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		u, ok := user(ctx)
		switch {
		case !ok:
			http.Error(rw, "user missing", http.StatusInternalServerError)
		case r.URL.Path == "/login":
			a.performLogin(ctx, u, rw, r)
		case r.URL.Path == "/oauth":
			a.handleOAuth(ctx, u, rw, r)
		case u == nil, len(u.Likes) == 0 && (u.OAuthToken == nil || !u.OAuthToken.Valid()):
			// Case expressions are evaluated left to right, so u is
			// dereferenced only when it is non-nil.
			a.redirectToLogin(rw, r)
		default:
			next.ServeHTTP(rw, r)
		}
	})
}
// likesResponse mirrors the part of the Tumblr likes API JSON response that
// refreshLikes reads: one page of liked posts plus the link to the next page.
type likesResponse struct {
	Response struct {
		// LikedPosts holds this page's posts.
		LikedPosts []Post `json:"liked_posts"`
		Links      struct {
			// Next.Href is the next page's path. refreshLikes appends it to
			// "https://api.tumblr.com" and stops paging when it is empty.
			Next struct {
				Href string `json:"href"`
			} `json:"next"`
		} `json:"_links"`
	} `json:"response"`
}
// refreshLikes re-downloads the user's complete list of liked posts from the
// Tumblr API, following the `_links.next.href` pagination cursor. It stores the
// result in u.Likes, persists it with save, and redirects (303) to
// "/?refreshed=true". Users without a valid OAuth token are redirected to
// /login. Any API, decode, or save failure is reported as 500.
func (a *app) refreshLikes(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	u, ok := user(ctx)
	if !ok || !u.OAuthToken.Valid() {
		a.redirectToLogin(rw, r)
		return
	}
	client := a.oauthConfig.Client(ctx, u.OAuthToken)

	const urlPrefix = "https://api.tumblr.com"
	var likes []Post
	for next := "/v2/user/likes"; next != ""; {
		page, err := fetchLikesPage(ctx, client, urlPrefix+next)
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
			return
		}
		posts := page.Response.LikedPosts
		if len(posts) == 0 {
			// Defensive: stop if the API keeps handing out next links
			// without any posts, rather than paging forever.
			break
		}
		likes = append(likes, posts...)
		next = page.Response.Links.Next.Href
	}

	u.Likes = likes
	if err := u.save(); err != nil {
		http.Error(rw, fmt.Sprintf("saving likes: %v", err), http.StatusInternalServerError)
		return
	}
	http.Redirect(rw, r, "/?refreshed=true", http.StatusSeeOther)
}

// fetchLikesPage GETs one page of likes from pageURL using client and decodes
// it. Non-200 responses become errors that include the upstream body. The
// response body is closed before returning, so paging does not accumulate open
// connections (unlike a defer inside the caller's loop).
func fetchLikesPage(ctx context.Context, client *http.Client, pageURL string) (*likesResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, pageURL, nil)
	if err != nil {
		return nil, fmt.Errorf("formulating likes request: %w", err)
	}
	log.Printf("fetching likes page %v", req.URL)
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("performing likes request: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code is still reported if the body can't be read.
		errBody, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("likes request status %v\n\n%v", resp.StatusCode, string(errBody))
	}
	var page likesResponse
	if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
		return nil, fmt.Errorf("decoding likes body as json: %w", err)
	}
	return &page, nil
}
// index renders the landing page: the number of cached likes and links to
// /random and /refresh. It is registered on "/", so the ServeMux also routes
// every otherwise-unmatched path here. Requests without a user in context are
// redirected to /login.
func (a *app) index(rw http.ResponseWriter, r *http.Request) {
	u, ok := user(r.Context())
	if !ok {
		a.redirectToLogin(rw, r)
		return
	}
	// %d is replaced with the number of cached likes.
	const indexPage = `
<!DOCTYPE html>
<h1>%d likes</h1>
<h2><a href="/random">random</a></h2>
<br>
<br>
<br>
<br>
<a href="/refresh">refresh</a>
`
	fmt.Fprintf(rw, indexPage, len(u.Likes))
}
// random redirects (303) to a uniformly chosen liked post, using crypto/rand
// for the choice. With no cached likes it goes back to "/". Requests without a
// user in context are redirected to /login.
func (a *app) random(rw http.ResponseWriter, r *http.Request) {
	u, ok := user(r.Context())
	if !ok {
		a.redirectToLogin(rw, r)
		return
	}

	total := len(u.Likes)
	if total == 0 {
		http.Redirect(rw, r, "/", http.StatusSeeOther)
		return
	}

	// rand.Int yields a value in [0, total).
	pick, err := rand.Int(rand.Reader, big.NewInt(int64(total)))
	if err != nil {
		http.Error(rw, err.Error(), http.StatusInternalServerError)
		return
	}
	target := u.Likes[pick.Int64()].PostURL
	http.Redirect(rw, r, target, http.StatusSeeOther)
}
// main parses flags, wires the handlers behind loadUserMiddleware and
// requireOAuthMiddleware, and serves on every address listed in -addr. It
// shuts down gracefully on SIGINT or when any listener fails; a second SIGINT
// terminates immediately.
func main() {
	flag.Parse()
	log.Println("using oauth client ID", *oauthClientID)
	if *baseURL == "" {
		log.Println(`warning: -base_url is empty, so the OAuth redirect URL is the bare path "/oauth"`)
	}

	// -addr is a comma-separated list. Ignore surrounding spaces and empty
	// entries; an empty entry would otherwise make net.Listen pick an
	// ephemeral port on all interfaces.
	var listenAddrs []string
	for _, part := range strings.Split(*addr, ",") {
		if part = strings.TrimSpace(part); part != "" {
			listenAddrs = append(listenAddrs, part)
		}
	}
	if len(listenAddrs) == 0 {
		log.Fatalf("no listen addresses in -addr=%q", *addr)
	}

	a := &app{
		oauthConfig: &oauth2.Config{
			ClientID:     *oauthClientID,
			ClientSecret: *oauthClientSecret,
			Endpoint: oauth2.Endpoint{
				AuthURL:  "https://www.tumblr.com/oauth2/authorize",
				TokenURL: "https://api.tumblr.com/v2/oauth2/token",
			},
			Scopes: []string{"basic"},
			// Trim a trailing slash so "https://host/" doesn't yield "https://host//oauth".
			RedirectURL: strings.TrimSuffix(*baseURL, "/") + "/oauth",
		},
		dataDir:              *dataDir,
		debugOnlyAssumeEmail: *debugOnlyAssumeEmail,
	}
	httpMux.HandleFunc("/refresh", a.refreshLikes)
	httpMux.HandleFunc("/", a.index)
	httpMux.HandleFunc("/random", a.random)
	s := &http.Server{
		Handler: a.loadUserMiddleware(a.requireOAuthMiddleware(httpMux)),
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	for _, singleAddr := range listenAddrs {
		singleAddr := singleAddr // per-iteration copy for pre-Go 1.22 loop semantics
		go func() {
			l, err := net.Listen("tcp", singleAddr)
			if err != nil {
				log.Printf("listen on %v: %v", singleAddr, err)
				cancel()
				return
			}
			log.Printf("starting to serve on %v", singleAddr)
			if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) {
				log.Printf("HTTP server Serve on %v: %v", singleAddr, err)
				cancel()
			}
		}()
	}

	ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
	defer stop()
	<-ctx.Done()
	log.Println("shutting down gracefully (SIGINT again to stop immediately)")
	// Restore default SIGINT handling so a second Ctrl-C kills the process.
	stop()
	if err := s.Shutdown(context.Background()); err != nil {
		log.Printf("HTTP server Shutdown: %v", err)
	}
}