// SPDX-FileCopyrightText: 2020 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0
package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"sync"
	"time"

	"github.com/dghubble/oauth1"
	"github.com/jackc/pgx/v4"
)

// Command-line configuration.
var (
	// databaseURL is handed to pgx.Connect unchanged.
	databaseURL = flag.String("database_url", "", "Database URL")
	// tickInterval is the period of the ticker that drives repeated fetches.
	tickInterval = flag.Duration("tick_interval", time.Minute, "Tick interval")
	// workers is the upper bound on concurrent fetch goroutines per tick.
	workers = flag.Int("twitter_workers", 5, "Workers to spawn for fetching info from Twitter API")
)

// user is one row of the user_accounts table: an account whose home
// timeline is fetched on every tick.
type user struct {
	// UserID and Username identify the account (used in queries and logs).
	UserID   int64
	Username string
	// AccessToken and AccessSecret form the OAuth1 token used for requests
	// made on behalf of this account.
	AccessToken  string
	AccessSecret string
	// LatestTweet is sent as since_id when non-zero; zero means no
	// since_id is sent.
	LatestTweet int64
}

// userAccounts loads every row of user_accounts and returns them as user
// values. Query, scan and iteration failures are each wrapped with a
// distinct prefix; on any failure the returned slice is nil.
func userAccounts(ctx context.Context, conn *pgx.Conn) ([]user, error) {
	const query = "SELECT userid, username, access_token, access_secret, latest_tweet FROM user_accounts"

	rows, err := conn.Query(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("retrieving user accounts: %w", err)
	}
	defer rows.Close()

	var accounts []user
	for rows.Next() {
		var acct user
		if scanErr := rows.Scan(&acct.UserID, &acct.Username, &acct.AccessToken, &acct.AccessSecret, &acct.LatestTweet); scanErr != nil {
			return nil, fmt.Errorf("scanning user account: %w", scanErr)
		}
		accounts = append(accounts, acct)
	}

	// Next() returning false may mean an error rather than end of data.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("fetching user accounts: %w", iterErr)
	}
	return accounts, nil
}

// tweet carries the fields extracted from one timeline entry together with
// the entry's original JSON.
type tweet struct {
	// ID and Text are decoded from the "id" and "full_text" JSON keys.
	ID   int64  `json:"id"`
	Text string `json:"full_text"`
	// Data is excluded from JSON (de)serialisation; fetchTweets fills it
	// with the raw JSON of the entry after decoding.
	Data json.RawMessage `json:"-"`
}

// insertTweet stores tw in the tweets table (id, text and raw object).
// Conflicting rows are left untouched (ON CONFLICT DO NOTHING). The error
// from the driver is returned as-is.
func insertTweet(ctx context.Context, conn *pgx.Conn, tw tweet) error {
	const stmt = "INSERT INTO tweets (id, text, object) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING"

	if _, err := conn.Exec(ctx, stmt, tw.ID, tw.Text, tw.Data); err != nil {
		return err
	}
	return nil
}

// fetchTweets requests the home timeline of user from the Twitter 1.1 API
// using the user's OAuth1 token and returns the decoded entries.
//
// The request asks for up to 200 entries in extended tweet mode; when
// user.LatestTweet is non-zero it is passed as since_id. Each returned
// tweet has Data set to the raw JSON of its timeline entry.
//
// An error is returned if the request cannot be built or sent, if the API
// answers with anything other than 200 OK, or if the response (or any
// single entry) fails to decode.
func fetchTweets(ctx context.Context, twitterOAuthConfig *oauth1.Config, user user) ([]tweet, error) {
	log.Printf("Fetching tweets for %s (%d)", user.Username, user.UserID)

	httpClient := twitterOAuthConfig.Client(ctx, oauth1.NewToken(user.AccessToken, user.AccessSecret))

	// Build the request with the context attached from the start instead
	// of creating it and then copying it via WithContext.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.twitter.com/1.1/statuses/home_timeline.json?count=200&include_entities=true&exclude_replies=false&tweet_mode=extended", nil)
	if err != nil {
		return nil, fmt.Errorf("constructing request to Twitter API: %w", err)
	}
	q := req.URL.Query()
	if user.LatestTweet != 0 {
		q.Set("since_id", fmt.Sprintf("%d", user.LatestTweet))
	}
	req.URL.RawQuery = q.Encode()

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("querying Twitter API: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Twitter API returned status %s", resp.Status)
	}

	// Decode in two steps so the untouched JSON of every entry can be kept
	// alongside the extracted fields.
	var tweetsRaw []json.RawMessage
	if err := json.NewDecoder(resp.Body).Decode(&tweetsRaw); err != nil {
		return nil, fmt.Errorf("decoding JSON from Twitter API: %w", err)
	}

	tweets := make([]tweet, 0, len(tweetsRaw))
	for _, tweetRaw := range tweetsRaw {
		var tw tweet
		if err := json.Unmarshal(tweetRaw, &tw); err != nil {
			return nil, fmt.Errorf("decoding tweet: %w\n%v", err, string(tweetRaw))
		}
		tw.Data = tweetRaw
		tweets = append(tweets, tw)
	}
	return tweets, nil
}

// userResult summarises what one fetch produced for a single account.
type userResult struct {
	// UserID is the account the result belongs to.
	UserID int64
	// TweetID is written to user_accounts.latest_tweet by updateUser.
	TweetID int64
	// SeenTweetIDs are the tweet IDs linked to the account in
	// user_accounts_tweets by updateUser.
	SeenTweetIDs []int64
}

// updateUser persists the outcome of one fetch for one account in a single
// batch: it sets user_accounts.latest_tweet to ur.TweetID and inserts one
// user_accounts_tweets row per entry of ur.SeenTweetIDs (existing rows are
// left alone via ON CONFLICT DO NOTHING).
//
// The result of every queued statement is read, and the error from closing
// the batch is reported, so a failure in any statement reaches the caller
// rather than only a failure of the first one.
func updateUser(ctx context.Context, conn *pgx.Conn, ur userResult) error {
	var b pgx.Batch
	b.Queue("UPDATE user_accounts SET latest_tweet=$1 WHERE userid=$2", ur.TweetID, ur.UserID)
	for _, twid := range ur.SeenTweetIDs {
		b.Queue("INSERT INTO user_accounts_tweets (userid, tweetid) VALUES ($1, $2) ON CONFLICT DO NOTHING", ur.UserID, twid)
	}
	// One UPDATE plus one INSERT per seen tweet.
	queued := 1 + len(ur.SeenTweetIDs)

	log.Printf("sending batch of updates for %d...", ur.UserID)
	start := time.Now()
	br := conn.SendBatch(ctx, &b)

	// Each Exec consumes the result of the next queued statement; stopping
	// after the first would hide errors from the INSERTs.
	// NOTE(review): pgx presumably runs the batch as one implicit
	// transaction, so a failing INSERT would also undo the UPDATE — confirm.
	var err error
	for i := 0; i < queued; i++ {
		if _, execErr := br.Exec(); execErr != nil {
			err = fmt.Errorf("executing batch statement %d for user %d: %w", i, ur.UserID, execErr)
			break
		}
	}
	// Close must always run to release the connection; keep the first error.
	if closeErr := br.Close(); closeErr != nil && err == nil {
		err = fmt.Errorf("closing batch for user %d: %w", ur.UserID, closeErr)
	}

	log.Printf("batch of updates for %d done in %s", ur.UserID, time.Since(start))
	return err
}

// tick performs one ingestion pass over all accounts.
//
// It loads the accounts, fans them out to at most *workers goroutines that
// call fetchTweets, and funnels everything the workers produce back to the
// calling goroutine, which is the only one that touches conn. Every fetched
// tweet is passed to insertTweet; for every account that yielded at least
// one tweet with a non-zero ID, updateUser is called with the largest ID
// and the IDs seen in that fetch.
//
// Only a failure to load the accounts is returned. Per-account fetch
// failures and per-row database failures are logged and skipped.
func tick(ctx context.Context, conn *pgx.Conn, twitterOAuthConfig *oauth1.Config) error {
	users, err := userAccounts(ctx, conn)
	if err != nil {
		return err
	}
	if len(users) == 0 {
		// Nothing to fetch, so there is no pipeline to set up.
		return nil
	}

	userCh := make(chan user)
	resultCh := make(chan tweet)
	userResultCh := make(chan userResult)

	wCnt := *workers
	if wCnt > len(users) {
		wCnt = len(users)
	}
	if wCnt < 1 {
		// A zero or negative flag value would leave the feeder blocked
		// forever (or panic in wg.Add); always run at least one worker.
		wCnt = 1
	}

	var wg sync.WaitGroup
	wg.Add(wCnt)
	go func() {
		// Once every worker has finished nothing more will be sent, so the
		// result channels can be closed to end the collector loop below.
		wg.Wait()
		close(resultCh)
		close(userResultCh)
	}()
	for n := 0; n < wCnt; n++ {
		go func() {
			defer wg.Done()

			for u := range userCh {
				tweets, err := fetchTweets(ctx, twitterOAuthConfig, u)
				if err != nil {
					log.Printf("failed to fetch tweets for %v: %v", u.Username, err)
					continue
				}

				// These are per-account results: they must start fresh for
				// each account so one account's tweets and high-water mark
				// are never attributed to the next account this worker
				// handles. A fresh slice also means the one already handed
				// to the collector is never appended to again.
				var largestTweetID int64
				seenTweetIDs := make([]int64, 0, len(tweets))
				for _, tw := range tweets {
					resultCh <- tw
					seenTweetIDs = append(seenTweetIDs, tw.ID)
					if tw.ID > largestTweetID {
						largestTweetID = tw.ID
					}
				}
				log.Printf("ingested tweets for %v (%d tweets, largest ID %d)", u.Username, len(tweets), largestTweetID)

				// Sent only after all of this account's tweets have been
				// handed over on the unbuffered resultCh, so the collector
				// inserts the tweets before the rows that refer to them.
				if largestTweetID != 0 {
					userResultCh <- userResult{
						UserID:       u.UserID,
						TweetID:      largestTweetID,
						SeenTweetIDs: seenTweetIDs,
					}
				}
			}
		}()
	}
	go func() {
		defer close(userCh)
		for _, u := range users {
			userCh <- u
		}
	}()

	// Collector: runs on the calling goroutine so that conn is used from a
	// single goroutine only. A closed channel is replaced by nil, which
	// removes its case from the select; the loop ends when both are nil.
	for resultCh != nil || userResultCh != nil {
		select {
		case tw, ok := <-resultCh:
			if !ok {
				resultCh = nil
				continue
			}
			if err := insertTweet(ctx, conn, tw); err != nil {
				// %s prints the raw JSON as text rather than as a list of
				// byte values.
				log.Printf("Failed to insert tweet %d: %v\n%s", tw.ID, err, tw.Data)
			}
		case ur, ok := <-userResultCh:
			if !ok {
				userResultCh = nil
				continue
			}
			if err := updateUser(ctx, conn, ur); err != nil {
				log.Printf("Failed to update user %d: %v", ur.UserID, err)
			}
		}
	}

	return nil
}

// main connects to the database named by -database_url, reads the Twitter
// consumer credentials from TWITTER_OAUTH_CONSUMER_KEY and
// TWITTER_OAUTH_CONSUMER_SECRET, runs one tick immediately and then one
// tick per -tick_interval until the process receives an interrupt signal.
//
// It exits with status 1 if the database connection cannot be established
// or either credential is missing. Failed ticks are logged and do not stop
// the loop.
func main() {
	flag.Parse()

	ctx := context.Background()

	conn, err := pgx.Connect(ctx, *databaseURL)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err)
		os.Exit(1)
	}
	defer conn.Close(ctx)

	ckey, csecret := os.Getenv("TWITTER_OAUTH_CONSUMER_KEY"), os.Getenv("TWITTER_OAUTH_CONSUMER_SECRET")
	if ckey == "" || csecret == "" {
		fmt.Fprintf(os.Stderr, "No TWITTER_OAUTH_CONSUMER_KEY or TWITTER_OAUTH_CONSUMER_SECRET\n")
		// os.Exit does not run deferred calls, so close the connection here.
		conn.Close(ctx)
		os.Exit(1)
	}
	twitterOAuthConfig := oauth1.NewConfig(ckey, csecret)

	stop := func() <-chan struct{} {
		// signal.Notify never blocks when delivering, so the channel needs
		// room for one signal or an interrupt arriving before the goroutine
		// below is receiving would be dropped.
		st := make(chan os.Signal, 1)
		signal.Notify(st, os.Interrupt)
		stopCh := make(chan struct{})
		go func() {
			<-st
			log.Printf("Shutting down")
			close(stopCh)
			// Restore default handling so a second interrupt terminates
			// the process even if shutdown is stuck behind a running tick.
			signal.Reset(os.Interrupt)
		}()
		return stopCh
	}()

	log.Printf("Started up.")
	log.Printf("Doing initial fetch.")
	if err := tick(ctx, conn, twitterOAuthConfig); err != nil {
		log.Printf("tick failed: %v", err)
	}

	t := time.NewTicker(*tickInterval)
	defer t.Stop()
loop:
	for {
		select {
		case <-t.C:
			log.Printf("Tick...")
			if err := tick(ctx, conn, twitterOAuthConfig); err != nil {
				log.Printf("tick failed: %v", err)
			}
		case <-stop:
			break loop
		}
	}
}