417 lines
11 KiB
Go
417 lines
11 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"log"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
var (
	// debugOnlyAssumeEmail, when set, replaces the Pomerium email header for
	// every request. Intended for local development only.
	debugOnlyAssumeEmail = flag.String("debug_only_assume_email", "", "If set, ignores the x-pomerium-claim-email header and just uses this static address.")

	// addr is split on "," by main; each entry is passed to net.Listen("tcp", ...).
	addr = flag.String("addr", "localhost:10908", "Comma-separated list of addresses (host:port) to listen on.")

	// dataDir holds one JSON file per user; defaults to $STATE_DIRECTORY.
	dataDir = flag.String("data_dir", os.Getenv("STATE_DIRECTORY"), "Data directory to store data in.")

	// OAuth client credentials; default to $OAUTH_CLIENT_ID / $OAUTH_CLIENT_SECRET.
	oauthClientID     = flag.String("oauth_client_id", os.Getenv("OAUTH_CLIENT_ID"), "Sets the OAuth client ID to use.")
	oauthClientSecret = flag.String("oauth_client_secret", os.Getenv("OAUTH_CLIENT_SECRET"), "Sets the OAuth client secret to use.")

	// baseURL must include the scheme: main appends "/oauth" to it to build
	// the OAuth redirect URL.
	baseURL = flag.String("base_url", "", `Publicly-visible base URL of this server, including scheme (e.g. https://likes.example.com); "/oauth" is appended to form the OAuth redirect URL.`)

	// httpMux routes the page handlers; main wraps it in the user-loading and
	// OAuth-gating middlewares.
	httpMux = http.NewServeMux()
)
|
|
|
|
type Post struct {
|
|
PostURL string `json:"post_url"`
|
|
}
|
|
|
|
type User struct {
|
|
originPath string
|
|
email string
|
|
|
|
OAuthToken *oauth2.Token `json:"oauth_token"`
|
|
Likes []Post `json:"likes"`
|
|
}
|
|
|
|
func (u *User) save() error {
|
|
p := u.originPath
|
|
if p == "" {
|
|
return fmt.Errorf("originPath somehow unset")
|
|
}
|
|
|
|
bs, err := json.Marshal(u)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return os.WriteFile(p, bs, 0600)
|
|
}
|
|
|
|
// pathForUser returns <dataDir>/<lowercase hex SHA-256 of email>.json, the
// file holding that user's state. Hashing keeps the raw email text out of
// the filesystem path.
func pathForUser(email, dataDir string) string {
	digest := sha256.Sum256([]byte(email))
	name := fmt.Sprintf("%x.json", digest)
	return filepath.Join(dataDir, name)
}
|
|
|
|
func loadUser(email, dataDir string) (*User, error) {
|
|
p := pathForUser(email, dataDir)
|
|
u := &User{originPath: p, email: email}
|
|
|
|
d, err := os.ReadFile(p)
|
|
if errors.Is(err, fs.ErrNotExist) {
|
|
log.Printf("generating new user for %v (path= %v)", email, p)
|
|
return u, nil
|
|
}
|
|
if err := json.Unmarshal(d, u); err != nil {
|
|
return nil, fmt.Errorf("unmarshalling JSON from %v: %w", p, err)
|
|
}
|
|
log.Printf("loaded user for %v from %v", email, p)
|
|
return u, nil
|
|
}
|
|
|
|
// pomeriumEmailHeader is the request header from which loadUserMiddleware
// reads the caller's email (set by the Pomerium proxy, per its name).
const pomeriumEmailHeader = "x-pomerium-claim-email"

// oauthStateCookie carries the random OAuth state from performLogin to
// handleOAuth.
const oauthStateCookie = "tumblr-oauth-state"

// contextKeyType is an unexported key type, so values stored under it cannot
// collide with context keys from other packages.
type contextKeyType struct{}

// userContextKey is the key under which the request's *User is stored.
var userContextKey contextKeyType
|
|
|
|
func user(ctx context.Context) (*User, bool) {
|
|
u, ok := ctx.Value(userContextKey).(*User)
|
|
return u, ok
|
|
}
|
|
|
|
type app struct {
|
|
oauthConfig *oauth2.Config
|
|
dataDir string
|
|
debugOnlyAssumeEmail string
|
|
}
|
|
|
|
func (a *app) loadUserMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
email := a.debugOnlyAssumeEmail
|
|
if email == "" {
|
|
email = r.Header.Get(pomeriumEmailHeader)
|
|
}
|
|
if email == "" {
|
|
http.Error(rw, "no email", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
u, err := loadUser(email, a.dataDir)
|
|
if err != nil {
|
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), userContextKey, u)))
|
|
})
|
|
}
|
|
|
|
func (a *app) performLogin(ctx context.Context, u *User, rw http.ResponseWriter, r *http.Request) {
|
|
randData := make([]byte, 32)
|
|
if n, err := rand.Read(randData); err != nil {
|
|
http.Error(rw, "bad random: "+err.Error(), http.StatusInternalServerError)
|
|
return
|
|
} else if n != len(randData) {
|
|
http.Error(rw, fmt.Sprintf("bad random: got %d of random not %d", n, len(randData)), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
stateVal := hex.EncodeToString(randData)
|
|
http.SetCookie(rw, &http.Cookie{
|
|
Name: oauthStateCookie,
|
|
Value: stateVal,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
http.Redirect(rw, r, a.oauthConfig.AuthCodeURL(stateVal, oauth2.AccessTypeOnline), http.StatusSeeOther)
|
|
}
|
|
|
|
func (a *app) handleOAuth(ctx context.Context, u *User, rw http.ResponseWriter, r *http.Request) {
|
|
cookie, err := r.Cookie(oauthStateCookie)
|
|
if err != nil {
|
|
http.Error(rw, fmt.Sprintf("getting state cookie: %v", err), http.StatusBadRequest)
|
|
return
|
|
}
|
|
if cookie.Value != r.URL.Query().Get("state") {
|
|
http.Error(rw, "invalid state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
tok, err := a.oauthConfig.Exchange(ctx, r.URL.Query().Get("code"))
|
|
if err != nil {
|
|
http.Error(rw, fmt.Sprintf("oauth2 exchange: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
u.OAuthToken = tok
|
|
if err := u.save(); err != nil {
|
|
http.Error(rw, fmt.Sprintf("persisting user: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(rw, r, "/refresh", http.StatusSeeOther)
|
|
}
|
|
|
|
func (a *app) redirectToLogin(rw http.ResponseWriter, r *http.Request) {
|
|
http.Redirect(rw, r, "/login", http.StatusSeeOther)
|
|
}
|
|
|
|
func (a *app) requireOAuthMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
u, ok := user(ctx)
|
|
if !ok {
|
|
http.Error(rw, "user missing", http.StatusInternalServerError)
|
|
} else if r.URL.Path == "/login" {
|
|
a.performLogin(ctx, u, rw, r)
|
|
} else if r.URL.Path == "/oauth" {
|
|
a.handleOAuth(ctx, u, rw, r)
|
|
} else if u == nil || (len(u.Likes) == 0 && (u.OAuthToken == nil || !u.OAuthToken.Valid())) {
|
|
a.redirectToLogin(rw, r)
|
|
} else {
|
|
next.ServeHTTP(rw, r)
|
|
}
|
|
})
|
|
}
|
|
|
|
type likesResponse struct {
|
|
Response struct {
|
|
LikedPosts []Post `json:"liked_posts"`
|
|
Links struct {
|
|
Next struct {
|
|
Href string `json:"href"`
|
|
} `json:"next"`
|
|
} `json:"_links"`
|
|
} `json:"response"`
|
|
}
|
|
|
|
var (
|
|
refreshBuf map[string][]Post
|
|
refreshBufMu sync.Mutex
|
|
)
|
|
|
|
func (a *app) refreshLikes(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
u, ok := user(ctx)
|
|
if !ok || !u.OAuthToken.Valid() {
|
|
a.redirectToLogin(rw, r)
|
|
return
|
|
}
|
|
|
|
client := a.oauthConfig.Client(ctx, u.OAuthToken)
|
|
|
|
var likes []Post
|
|
const urlPrefix = "https://api.tumblr.com"
|
|
refreshSess := r.FormValue("refresh_session")
|
|
if refreshSess == "" {
|
|
refreshSessBytes := make([]byte, 32)
|
|
if _, err := rand.Read(refreshSessBytes); err != nil {
|
|
http.Error(rw, fmt.Sprintf("generating refresh session token: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
refreshSess = base64.RawURLEncoding.EncodeToString(refreshSessBytes)
|
|
refreshBufMu.Lock()
|
|
_, ok := refreshBuf[refreshSess]
|
|
refreshBufMu.Unlock()
|
|
if ok {
|
|
http.Error(rw, fmt.Sprintf("randomness isn't random enough"), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
} else {
|
|
refreshBufMu.Lock()
|
|
likes2, ok := refreshBuf[refreshSess]
|
|
refreshBufMu.Unlock()
|
|
if !ok {
|
|
http.Error(rw, fmt.Sprintf("refresh session %q is missing", refreshSess), http.StatusBadRequest)
|
|
return
|
|
}
|
|
likes = likes2
|
|
}
|
|
|
|
urlSuffix := r.FormValue("url_suffix")
|
|
if urlSuffix == "" {
|
|
urlSuffix = "/v2/user/likes"
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", urlPrefix+urlSuffix, nil)
|
|
if err != nil {
|
|
http.Error(rw, fmt.Sprintf("formulating likes request: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
fmt.Println(req.URL)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
http.Error(rw, fmt.Sprintf("performing likes request: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
errBody, _ := io.ReadAll(resp.Body)
|
|
http.Error(rw, fmt.Sprintf("likes request status %v\n\n%v", resp.StatusCode, string(errBody)), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
var lr likesResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&lr); err != nil {
|
|
http.Error(rw, fmt.Sprintf("decoding likes body as json: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
likes = append(likes, lr.Response.LikedPosts...)
|
|
|
|
urlSuffix = lr.Response.Links.Next.Href
|
|
if urlSuffix != "" {
|
|
refreshBufMu.Lock()
|
|
refreshBuf[refreshSess] = likes
|
|
refreshBufMu.Unlock()
|
|
fmt.Fprintf(rw, `<!DOCTYPE html>
|
|
<html>
|
|
<head>
|
|
<meta charset="utf-8">
|
|
</head>
|
|
<body>
|
|
<form method="POST" id="theForm">
|
|
<input type="hidden" name="url_suffix" value=%q>
|
|
<input type="hidden" name="refresh_session" value=%q>
|
|
<input type="submit" id="theSubmit">
|
|
</form>
|
|
<script>
|
|
document.querySelector('#theSubmit').disabled = true;
|
|
document.querySelector('#theForm').submit();
|
|
</script>
|
|
</body>
|
|
</html>
|
|
`, urlSuffix, refreshSess)
|
|
return
|
|
}
|
|
refreshBufMu.Lock()
|
|
delete(refreshBuf, refreshSess)
|
|
refreshBufMu.Unlock()
|
|
|
|
u.Likes = likes
|
|
if err := u.save(); err != nil {
|
|
http.Error(rw, fmt.Sprintf("saving likes: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(rw, r, "/?refreshed=true", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
func (a *app) index(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
u, ok := user(ctx)
|
|
if !ok {
|
|
a.redirectToLogin(rw, r)
|
|
return
|
|
}
|
|
|
|
fmt.Fprintf(rw, `
|
|
<!DOCTYPE html>
|
|
<h1>%d likes</h1>
|
|
<h2><a href="/random">random</a></h2>
|
|
<br>
|
|
<br>
|
|
<br>
|
|
<br>
|
|
<a href="/refresh">refresh</a>
|
|
`, len(u.Likes))
|
|
}
|
|
|
|
func (a *app) random(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
u, ok := user(ctx)
|
|
if !ok {
|
|
a.redirectToLogin(rw, r)
|
|
return
|
|
}
|
|
if len(u.Likes) == 0 {
|
|
http.Redirect(rw, r, "/", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
bigN, err := rand.Int(rand.Reader, big.NewInt(int64(len(u.Likes))))
|
|
if err != nil {
|
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
n := int(bigN.Int64())
|
|
l := u.Likes[n]
|
|
http.Redirect(rw, r, l.PostURL, http.StatusSeeOther)
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
log.Println("using oauth client ID", *oauthClientID)
|
|
a := &app{
|
|
oauthConfig: &oauth2.Config{
|
|
ClientID: *oauthClientID,
|
|
ClientSecret: *oauthClientSecret,
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://www.tumblr.com/oauth2/authorize",
|
|
TokenURL: "https://api.tumblr.com/v2/oauth2/token",
|
|
},
|
|
Scopes: []string{"basic"},
|
|
RedirectURL: fmt.Sprintf("%s/oauth", *baseURL),
|
|
},
|
|
dataDir: *dataDir,
|
|
debugOnlyAssumeEmail: *debugOnlyAssumeEmail,
|
|
}
|
|
httpMux.HandleFunc("/refresh", a.refreshLikes)
|
|
httpMux.HandleFunc("/", a.index)
|
|
httpMux.HandleFunc("/random", a.random)
|
|
|
|
s := &http.Server{
|
|
Handler: a.loadUserMiddleware(a.requireOAuthMiddleware(httpMux)),
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
for _, singleAddr := range strings.Split(*addr, ",") {
|
|
singleAddr := singleAddr
|
|
go func() {
|
|
l, err := net.Listen("tcp", singleAddr)
|
|
if err != nil {
|
|
log.Printf("listen on %v: %v", singleAddr, err)
|
|
cancel()
|
|
return
|
|
}
|
|
log.Printf("starting to serve on %v", singleAddr)
|
|
if err := s.Serve(l); err != http.ErrServerClosed {
|
|
log.Printf("HTTP server ListenAndServe: %v", err)
|
|
cancel()
|
|
}
|
|
}()
|
|
}
|
|
|
|
ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
|
|
defer stop()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Println("shutting down gracefully (SIGINT again to stop immediately)")
|
|
stop()
|
|
if err := s.Shutdown(context.Background()); err != nil {
|
|
log.Printf("HTTP server Shutdown: %v", err)
|
|
}
|
|
}
|
|
}
|