// SPDX-FileCopyrightText: 2020 Luke Granger-Brown // // SPDX-License-Identifier: Apache-2.0 package main import ( "context" "encoding/json" "flag" "fmt" "log" "net/http" "os" "os/signal" "sync" "time" "github.com/dghubble/oauth1" "github.com/jackc/pgx/v4" ) var ( databaseURL = flag.String("database_url", "", "Database URL") tickInterval = flag.Duration("tick_interval", 1*time.Minute, "Tick interval") workers = flag.Int("twitter_workers", 5, "Workers to spawn for fetching info from Twitter API") ) type user struct { UserID int64 Username string AccessToken, AccessSecret string LatestTweet int64 } func userAccounts(ctx context.Context, conn *pgx.Conn) ([]user, error) { rows, err := conn.Query(ctx, "SELECT userid, username, access_token, access_secret, latest_tweet FROM user_accounts") if err != nil { return nil, fmt.Errorf("retrieving user accounts: %w", err) } defer rows.Close() var users []user for rows.Next() { var u user err = rows.Scan(&u.UserID, &u.Username, &u.AccessToken, &u.AccessSecret, &u.LatestTweet) if err != nil { return nil, fmt.Errorf("scanning user account: %w", err) } users = append(users, u) } if rows.Err() != nil { return nil, fmt.Errorf("fetching user accounts: %w", rows.Err()) } return users, nil } type tweet struct { ID int64 `json:"id"` Text string `json:"full_text"` Data json.RawMessage `json:"-"` } func insertTweet(ctx context.Context, conn *pgx.Conn, tw tweet) error { _, err := conn.Exec(ctx, "INSERT INTO tweets (id, text, object) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING", tw.ID, tw.Text, tw.Data) return err } func fetchTweets(ctx context.Context, twitterOAuthConfig *oauth1.Config, user user) ([]tweet, error) { log.Printf("Fetching tweets for %s (%d)", user.Username, user.UserID) httpClient := twitterOAuthConfig.Client(ctx, oauth1.NewToken(user.AccessToken, user.AccessSecret)) req, err := http.NewRequest("GET", "https://api.twitter.com/1.1/statuses/home_timeline.json?count=200&include_entities=true&exclude_replies=false&tweet_mode=extended", nil) if err != nil { return nil, fmt.Errorf("constructing request to Twitter API: %w", err) } req = req.WithContext(ctx) q := req.URL.Query() if user.LatestTweet != 0 { q.Set("since_id", fmt.Sprintf("%d", user.LatestTweet)) } req.URL.RawQuery = q.Encode() resp, err := httpClient.Do(req) if err != nil { return nil, fmt.Errorf("querying Twitter API: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("Twitter API returned status %s", resp.Status) } var tweetsRaw []json.RawMessage if err := json.NewDecoder(resp.Body).Decode(&tweetsRaw); err != nil { return nil, fmt.Errorf("decoding JSON from Twitter API: %w", err) } tweets := make([]tweet, 0, len(tweetsRaw)) for _, tweetRaw := range tweetsRaw { var tw tweet if err := json.Unmarshal([]byte(tweetRaw), &tw); err != nil { return nil, fmt.Errorf("decoding tweet: %w\n%v", err, string(tweetRaw)) } tw.Data = tweetRaw tweets = append(tweets, tw) } return tweets, nil } type userResult struct { UserID int64 TweetID int64 SeenTweetIDs []int64 } func updateUser(ctx context.Context, conn *pgx.Conn, ur userResult) error { var b pgx.Batch b.Queue("UPDATE user_accounts SET latest_tweet=$1 WHERE userid=$2", ur.TweetID, ur.UserID) for _, twid := range ur.SeenTweetIDs { b.Queue("INSERT INTO user_accounts_tweets (userid, tweetid) VALUES ($1, $2) ON CONFLICT DO NOTHING", ur.UserID, twid) } log.Printf("sending batch of updates for %d...", ur.UserID) start := time.Now() br := conn.SendBatch(ctx, &b) defer br.Close() _, err := br.Exec() log.Printf("batch of updates for %d done in %s", ur.UserID, time.Now().Sub(start)) return err } func tick(ctx context.Context, conn *pgx.Conn, twitterOAuthConfig *oauth1.Config) error { users, err := userAccounts(ctx, conn) if err != nil { return err } userCh := make(chan user) resultCh := make(chan tweet) userResultCh := make(chan userResult) wCnt := *workers if wCnt > len(users) { wCnt = len(users) } var wg sync.WaitGroup wg.Add(wCnt) go func() { wg.Wait() close(resultCh) close(userResultCh) }() for n := 0; n < wCnt; n++ { go func() { defer wg.Done() var largestTweetID int64 var seenTweetIDs []int64 var tweetCount int for user := range userCh { tweets, err := fetchTweets(ctx, twitterOAuthConfig, user) if err != nil { log.Printf("failed to fetch tweets for %v: %v", user.Username, err) continue } for _, tw := range tweets { tweetCount++ resultCh <- tw seenTweetIDs = append(seenTweetIDs, tw.ID) if tw.ID > largestTweetID { largestTweetID = tw.ID } } log.Printf("ingested tweets for %v (%d tweets, largest ID %d)", user.Username, tweetCount, largestTweetID) if largestTweetID != 0 { userResultCh <- userResult{ UserID: user.UserID, TweetID: largestTweetID, SeenTweetIDs: seenTweetIDs, } } } }() } go func() { defer close(userCh) for _, user := range users { userCh <- user } }() func() { resultCh, userResultCh := resultCh, userResultCh for { if resultCh == nil && userResultCh == nil { return } select { case tw, ok := <-resultCh: if !ok { resultCh = nil continue } if err := insertTweet(ctx, conn, tw); err != nil { log.Printf("Failed to insert tweet %d: %v\n%v", tw.ID, err, tw.Data) } case userResult, ok := <-userResultCh: if !ok { userResultCh = nil continue } if err := updateUser(ctx, conn, userResult); err != nil { log.Printf("Failed to update user %d: %v", userResult.UserID, err) } } } }() return nil } func main() { flag.Parse() ctx := context.Background() conn, err := pgx.Connect(ctx, *databaseURL) if err != nil { fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err) os.Exit(1) } defer conn.Close(ctx) ckey, csecret := os.Getenv("TWITTER_OAUTH_CONSUMER_KEY"), os.Getenv("TWITTER_OAUTH_CONSUMER_SECRET") if ckey == "" || csecret == "" { fmt.Fprintf(os.Stderr, "No TWITTER_OAUTH_CONSUMER_KEY or TWITTER_OAUTH_CONSUMER_SECRET\n") os.Exit(1) } twitterOAuthConfig := oauth1.NewConfig(ckey, csecret) stop := func() <-chan struct{} { st := make(chan os.Signal) signal.Notify(st, os.Interrupt) stopCh := make(chan struct{}) go func() { <-st log.Printf("Shutting down") close(stopCh) signal.Reset(os.Interrupt) }() return stopCh }() log.Printf("Started up.") log.Printf("Doing initial fetch.") if err := tick(ctx, conn, twitterOAuthConfig); err != nil { log.Printf("tick failed: %v", err) } t := time.NewTicker(*tickInterval) loop: for { select { case <-t.C: log.Printf("Tick...") if err := tick(ctx, conn, twitterOAuthConfig); err != nil { log.Printf("tick failed: %v", err) } case <-stop: break loop } } }