package main

import (
	"bufio"
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"strconv"
	"time"

	"github.com/dghubble/oauth1"
	"golang.org/x/sync/errgroup"
)

// Command-line flags controlling where tweet IDs come from, which tweets are
// old enough to delete, and how the deletion runs.
var (
	// tweetArchiveFile is the archive tweet .js file read when
	// --tweets_from_api is not set.
	tweetArchiveFile = flag.String("tweet_archive_file", "", "Tweet archive .js file")

	// tweetsFromAPI switches the source of tweet IDs from the archive file to
	// the timeline API.
	tweetsFromAPI = flag.Bool("tweets_from_api", false, "Fetch tweets from API")

	// tweetCutoff is how far back from now the deletion cutoff lies. The
	// default of 90 days is what the help text calls 3 'months'.
	tweetCutoff = flag.Duration("tweet_cutoff", 90*24*time.Hour, "Cutoff lookback (default is 3 'months')")

	// dryRun defaults to true, so nothing is deleted unless --dryrun=false
	// is passed explicitly.
	dryRun = flag.Bool("dryrun", true, "Dry run: don't delete anything")

	// deleteThreads is the number of concurrent delete workers.
	deleteThreads = flag.Int("delete_threads", 64, "Concurrent threads to use for deleting tweets.")
)

// archivedTweetTime is a time.Time that decodes from the "created_at" string
// format used by tweets, e.g. "Wed Oct 10 20:19:24 +0000 2018"
// (the time.RubyDate layout).
type archivedTweetTime struct {
	time.Time
}

// UnmarshalJSON implements json.Unmarshaler. It expects a JSON string in
// time.RubyDate layout. A JSON null is a no-op that leaves t unchanged, which
// follows the encoding/json convention for Unmarshalers. Any other non-string
// value, or a string that does not match the layout, yields an error.
func (t *archivedTweetTime) UnmarshalJSON(b []byte) error {
	if string(b) == "null" {
		return nil
	}
	var s string
	if err := json.Unmarshal(b, &s); err != nil {
		return err
	}
	// time.RubyDate == "Mon Jan 02 15:04:05 -0700 2006".
	pt, err := time.Parse(time.RubyDate, s)
	if err != nil {
		return fmt.Errorf("parsing tweet time %q: %w", s, err)
	}
	t.Time = pt
	return nil
}

// archivedTweet is one element of the JSON array in a Twitter archive tweet
// file. Only the nested tweet's string ID and creation time are decoded; all
// other fields are ignored.
type archivedTweet struct {
	Tweet struct {
		CreatedAt archivedTweetTime `json:"created_at"`
		Id        string            `json:"id_str"`
	} `json:"tweet"`
}

func loadTweetIDsFromFile(ctx context.Context, tweetIDsFile string, tweetCutoff time.Time) ([]int64, error) {
	f, err := os.Open(tweetIDsFile)
	if err != nil {
		return nil, fmt.Errorf("os.Open(%q): %w", tweetIDsFile, err)
	}
	defer f.Close()

	if _, err := io.CopyN(ioutil.Discard, f, int64(len("window.YTD.tweet.part0 = "))); err != nil {
		return nil, fmt.Errorf("io.CopyN: %w", err)
	}

	var ts []archivedTweet
	if err := json.NewDecoder(f).Decode(&ts); err != nil {
		return nil, fmt.Errorf("json Decode: %w", err)
	}

	var out []int64
	for _, t := range ts {
		tID, err := strconv.ParseInt(t.Tweet.Id, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("ParseInt tweet ID %q: %w", t.Tweet.Id, err)
		}

		if t.Tweet.CreatedAt.Before(tweetCutoff) {
			out = append(out, tID)
		}
	}
	return out, nil
}

// apiTweet is the subset of a statuses/user_timeline tweet object that this
// tool decodes: the string form of the tweet ID and its creation timestamp.
type apiTweet struct {
	CreatedAt archivedTweetTime `json:"created_at"`
	Id        string            `json:"id_str"`
}

// fetchTweets requests one page of up to 200 tweets (count=200) from the v1.1
// statuses/user_timeline endpoint. When maxID is non-zero, max_id=maxID-1 is
// added so that paging continues strictly below maxID. Any status other than
// 200 is returned as an error that includes the status and a short body
// excerpt.
func fetchTweets(ctx context.Context, httpClient *http.Client, maxID int64) ([]apiTweet, error) {
	// NOTE(review): "user_id=me" is sent as-is; confirm the API resolves it to
	// the authenticated user.
	endpoint := "https://api.twitter.com/1.1/statuses/user_timeline.json?user_id=me&count=200"
	if maxID != 0 {
		// Subtract 1 so the tweet with ID maxID is not returned again
		// (assumes max_id is inclusive).
		endpoint += fmt.Sprintf("&max_id=%d", maxID-1)
	}
	req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("http.NewRequest for timeline max tweet %d: %w", maxID, err)
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("http.Do for maxID %d: %w", maxID, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// err is nil at this point, so report the status and a bounded body
		// excerpt instead. The read is best-effort: an error is returned
		// either way, so a failed read is deliberately ignored.
		body, _ := ioutil.ReadAll(io.LimitReader(resp.Body, 1024))
		return nil, fmt.Errorf("fetching timeline for maxID %d: got status %q: %q", maxID, resp.Status, body)
	}

	var tweets []apiTweet
	if err := json.NewDecoder(resp.Body).Decode(&tweets); err != nil {
		return nil, fmt.Errorf("decoding timeline JSON: %w", err)
	}

	return tweets, nil
}

// loadTweetIDsFromAPI pages backwards through the timeline with fetchTweets,
// starting from the newest page. It stops when a page comes back empty or
// fails to move backwards. It returns the IDs of every fetched tweet created
// before tweetCutoff, and logs the earliest tweet seen (or that none were
// returned).
func loadTweetIDsFromAPI(ctx context.Context, httpClient *http.Client, tweetCutoff time.Time) ([]int64, error) {
	var maxID int64 // smallest tweet ID seen so far; 0 before the first page
	var out []int64
	// earliest is stored by value. Keeping &t of the range variable would,
	// before Go 1.22, alias one variable that is overwritten every iteration.
	var earliest apiTweet
	var haveEarliest bool

	for {
		tweets, err := fetchTweets(ctx, httpClient, maxID)
		if err != nil {
			return nil, fmt.Errorf("fetchTweets(%d): %w", maxID, err)
		}
		log.Printf("fetched tweets with max ID %d", maxID)
		if len(tweets) == 0 {
			break
		}

		prevMaxID := maxID
		for _, t := range tweets {
			tID, err := strconv.ParseInt(t.Id, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("ParseInt tweet ID %q: %w", t.Id, err)
			}
			if maxID == 0 || tID < maxID {
				maxID = tID
			}

			if t.CreatedAt.Before(tweetCutoff) {
				out = append(out, tID)
			}
			if !haveEarliest || t.CreatedAt.Before(earliest.CreatedAt.Time) {
				earliest, haveEarliest = t, true
			}
		}
		// A non-empty page that did not move maxID backwards would make the
		// next request identical to this one; stop instead of looping forever.
		if maxID == prevMaxID {
			break
		}
	}

	if haveEarliest {
		log.Printf("Earliest tweet: %v %v (%v ago)", earliest.Id, earliest.CreatedAt.Time, time.Since(earliest.CreatedAt.Time))
	} else {
		log.Printf("No tweets returned by the API.")
	}
	return out, nil
}

// deleteTweet deletes one tweet via POST statuses/destroy/<tid>.json.
//
// A 200 or 404 response counts as success; 404 presumably means the tweet is
// already gone. Any other status is returned as an error naming the tweet ID
// and status. On success the response body is drained so the transport may
// reuse the connection.
func deleteTweet(ctx context.Context, httpClient *http.Client, tid int64) error {
	url := fmt.Sprintf("https://api.twitter.com/1.1/statuses/destroy/%d.json?trim_user=true", tid)
	req, err := http.NewRequestWithContext(ctx, "POST", url, nil)
	if err != nil {
		return fmt.Errorf("http.NewRequest for tid %d: %w", tid, err)
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("http.Do for tid %d: %w", tid, err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK, http.StatusNotFound:
		// Deleted, or already absent: both leave the tweet gone.
	default:
		return fmt.Errorf("deleting tid %d: got status %d %q", tid, resp.StatusCode, resp.Status)
	}

	if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
		return fmt.Errorf("io.Copy for tid %d: %w", tid, err)
	}

	return nil
}

// deleteTweets deletes every ID in tweetIDs. deleteThreads workers read IDs
// from an unbuffered channel fed by one producer goroutine, which logs progress
// every 100 IDs. The first worker failure cancels the shared context, stops the
// producer, and is the error returned. If ctx is cancelled first, the
// cancellation error is returned rather than nil. A deleteThreads below 1 is
// treated as 1.
func deleteTweets(ctx context.Context, httpClient *http.Client, deleteThreads int, tweetIDs []int64) error {
	// With zero workers the producer would block forever sending on the
	// unbuffered channel.
	if deleteThreads < 1 {
		deleteThreads = 1
	}
	deleteQueue := make(chan int64)

	eg, egCtx := errgroup.WithContext(ctx)
	for n := 0; n < deleteThreads; n++ {
		eg.Go(func() error {
			for tid := range deleteQueue {
				if err := deleteTweet(egCtx, httpClient, tid); err != nil {
					return fmt.Errorf("deleting tweet %d: %w", tid, err)
				}
			}
			return nil
		})
	}
	eg.Go(func() error {
		// Closing the queue lets idle workers exit their range loops.
		defer close(deleteQueue)
		for n, tid := range tweetIDs {
			if n%100 == 0 {
				log.Printf("Progress: %d / %d", n, len(tweetIDs))
			}
			select {
			case <-egCtx.Done():
				// Report the cancellation instead of silently succeeding with
				// IDs left unsent. If a worker failed first, errgroup's Wait
				// still returns that earlier error.
				return egCtx.Err()
			case deleteQueue <- tid:
			}
		}
		return nil
	})

	return eg.Wait()
}

func main() {
	flag.Parse()

	ctx := context.Background()
	log.Printf("Started up.")

	ckey, csecret := os.Getenv("TWITTER_OAUTH_CONSUMER_KEY"), os.Getenv("TWITTER_OAUTH_CONSUMER_SECRET")
	if ckey == "" || csecret == "" {
		fmt.Fprintf(os.Stderr, "No TWITTER_OAUTH_CONSUMER_KEY or TWITTER_OAUTH_CONSUMER_SECRET\n")
		os.Exit(1)
	}
	atoken, asecret := os.Getenv("TWITTER_OAUTH_ACCESS_TOKEN"), os.Getenv("TWITTER_OAUTH_ACCESS_SECRET")
	if atoken == "" || asecret == "" {
		fmt.Fprintf(os.Stderr, "No TWITTER_OAUTH_ACCESS_TOKEN or TWITTER_OAUTH_ACCESS_SECRET\n")
		os.Exit(1)
	}
	httpClient := oauth1.NewConfig(ckey, csecret).Client(ctx, oauth1.NewToken(atoken, asecret))
	log.Printf("Initialized OAuth config.")

	var tweetIDs []int64

	if *tweetsFromAPI {
		log.Printf("Loading tweets from API.")
		var err error
		tweetIDs, err = loadTweetIDsFromAPI(ctx, httpClient, time.Now().Add(-*tweetCutoff))
		if err != nil {
			log.Fatalf("loadTweetIDsFromAPI(%q): %v", *tweetCutoff, err)
		}
	} else {
		if *tweetArchiveFile == "" {
			log.Fatalf("--tweet_archive_file must be set.")
		}
		var err error
		tweetIDs, err = loadTweetIDsFromFile(ctx, *tweetArchiveFile, time.Now().Add(-*tweetCutoff))
		if err != nil {
			log.Fatalf("loadTweetIDsFromFile(%q, %q): %v", *tweetArchiveFile, *tweetCutoff, err)
		}
	}
	log.Printf("Got %d tweets to delete.", len(tweetIDs))

	if *dryRun {
		log.Printf("Dry run: not doing anything.")
		return
	}
	if err := deleteTweets(ctx, httpClient, *deleteThreads, tweetIDs); err != nil {
		log.Fatalf("deleteTweets: %v", err)
	}
}