// SPDX-FileCopyrightText: 2020 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/dghubble/oauth1"
|
|
"github.com/jackc/pgx/v4"
|
|
)
|
|
|
|
// Command-line configuration for the ingester.
var (
	// databaseURL is the connection string passed to pgx.Connect in main.
	databaseURL = flag.String("database_url", "", "Database URL")
	// tickInterval is the period of the ticker that drives repeated ingestion passes.
	tickInterval = flag.Duration("tick_interval", time.Minute, "Tick interval")
	// workers is the upper bound on timeline-fetching goroutines started per pass.
	workers = flag.Int("twitter_workers", 5, "Workers to spawn for fetching info from Twitter API")
)
|
|
|
|
// user is one row of the user_accounts table. It describes a Twitter account
// whose home timeline is ingested, the OAuth1 credentials used to fetch it,
// and the newest tweet ID already recorded for it.
type user struct {
	// UserID is the userid column; updates to latest_tweet are keyed on it.
	UserID int64
	// Username appears in log messages.
	Username string
	// AccessToken is the account's OAuth1 access token.
	AccessToken string
	// AccessSecret is the account's OAuth1 access token secret.
	AccessSecret string
	// LatestTweet is the latest_tweet column. When it is zero, no since_id is
	// sent to the Twitter API.
	LatestTweet int64
}
|
|
|
|
func userAccounts(ctx context.Context, conn *pgx.Conn) ([]user, error) {
|
|
rows, err := conn.Query(ctx, "SELECT userid, username, access_token, access_secret, latest_tweet FROM user_accounts")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("retrieving user accounts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var users []user
|
|
for rows.Next() {
|
|
var u user
|
|
err = rows.Scan(&u.UserID, &u.Username, &u.AccessToken, &u.AccessSecret, &u.LatestTweet)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("scanning user account: %w", err)
|
|
}
|
|
users = append(users, u)
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
return nil, fmt.Errorf("fetching user accounts: %w", rows.Err())
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
// tweet holds the fields of a Twitter API tweet object that are stored in
// their own columns, alongside the complete raw JSON of that object.
type tweet struct {
	// ID is decoded from the object's "id" field.
	ID int64 `json:"id"`
	// Text is decoded from "full_text"; fetchTweets requests tweet_mode=extended.
	Text string `json:"full_text"`

	// Data is the undecoded JSON for the whole tweet. JSON (un)marshalling
	// skips it; fetchTweets assigns it after decoding the other fields.
	Data json.RawMessage `json:"-"`
}
|
|
|
|
func insertTweet(ctx context.Context, conn *pgx.Conn, tw tweet) error {
|
|
_, err := conn.Exec(ctx, "INSERT INTO tweets (id, text, object) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING", tw.ID, tw.Text, tw.Data)
|
|
return err
|
|
}
|
|
|
|
func fetchTweets(ctx context.Context, twitterOAuthConfig *oauth1.Config, user user) ([]tweet, error) {
|
|
log.Printf("Fetching tweets for %s (%d)", user.Username, user.UserID)
|
|
|
|
httpClient := twitterOAuthConfig.Client(ctx, oauth1.NewToken(user.AccessToken, user.AccessSecret))
|
|
|
|
req, err := http.NewRequest("GET", "https://api.twitter.com/1.1/statuses/home_timeline.json?count=200&include_entities=true&exclude_replies=false&tweet_mode=extended", nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("constructing request to Twitter API: %w", err)
|
|
}
|
|
req = req.WithContext(ctx)
|
|
q := req.URL.Query()
|
|
if user.LatestTweet != 0 {
|
|
q.Set("since_id", fmt.Sprintf("%d", user.LatestTweet))
|
|
}
|
|
req.URL.RawQuery = q.Encode()
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying Twitter API: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("Twitter API returned status %s", resp.Status)
|
|
}
|
|
|
|
var tweetsRaw []json.RawMessage
|
|
if err := json.NewDecoder(resp.Body).Decode(&tweetsRaw); err != nil {
|
|
return nil, fmt.Errorf("decoding JSON from Twitter API: %w", err)
|
|
}
|
|
|
|
tweets := make([]tweet, 0, len(tweetsRaw))
|
|
for _, tweetRaw := range tweetsRaw {
|
|
var tw tweet
|
|
if err := json.Unmarshal([]byte(tweetRaw), &tw); err != nil {
|
|
return nil, fmt.Errorf("decoding tweet: %w\n%v", err, string(tweetRaw))
|
|
}
|
|
tw.Data = tweetRaw
|
|
tweets = append(tweets, tw)
|
|
}
|
|
return tweets, nil
|
|
}
|
|
|
|
// userResult is sent from a fetch worker to the goroutine that owns the
// database connection. It asks that TweetID be recorded as latest_tweet for
// the user_accounts row whose userid is UserID.
type userResult struct {
	UserID, TweetID int64
}
|
|
|
|
func updateUser(ctx context.Context, conn *pgx.Conn, ur userResult) error {
|
|
_, err := conn.Exec(ctx, "UPDATE user_accounts SET latest_tweet=$1 WHERE userid=$2", ur.TweetID, ur.UserID)
|
|
return err
|
|
}
|
|
|
|
func tick(ctx context.Context, conn *pgx.Conn, twitterOAuthConfig *oauth1.Config) error {
|
|
users, err := userAccounts(ctx, conn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
userCh := make(chan user)
|
|
resultCh := make(chan tweet)
|
|
userResultCh := make(chan userResult)
|
|
|
|
wCnt := *workers
|
|
if wCnt > len(users) {
|
|
wCnt = len(users)
|
|
}
|
|
var wg sync.WaitGroup
|
|
wg.Add(wCnt)
|
|
go func() {
|
|
wg.Wait()
|
|
close(resultCh)
|
|
close(userResultCh)
|
|
}()
|
|
for n := 0; n < wCnt; n++ {
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
var largestTweetID int64
|
|
var tweetCount int
|
|
for user := range userCh {
|
|
tweets, err := fetchTweets(ctx, twitterOAuthConfig, user)
|
|
if err != nil {
|
|
log.Printf("failed to fetch tweets for %v: %v", user.Username, err)
|
|
continue
|
|
}
|
|
for _, tw := range tweets {
|
|
tweetCount++
|
|
resultCh <- tw
|
|
if tw.ID > largestTweetID {
|
|
largestTweetID = tw.ID
|
|
}
|
|
}
|
|
log.Printf("ingested tweets for %v (%d tweets, largest ID %d)", user.Username, tweetCount, largestTweetID)
|
|
if largestTweetID != 0 {
|
|
userResultCh <- userResult{
|
|
UserID: user.UserID,
|
|
TweetID: largestTweetID,
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
go func() {
|
|
defer close(userCh)
|
|
for _, user := range users {
|
|
userCh <- user
|
|
}
|
|
}()
|
|
|
|
func() {
|
|
resultCh, userResultCh := resultCh, userResultCh
|
|
for {
|
|
if resultCh == nil && userResultCh == nil {
|
|
return
|
|
}
|
|
|
|
select {
|
|
case tw, ok := <-resultCh:
|
|
if !ok {
|
|
resultCh = nil
|
|
continue
|
|
}
|
|
if err := insertTweet(ctx, conn, tw); err != nil {
|
|
log.Printf("Failed to insert tweet %d: %v\n%v", tw.ID, err, tw.Data)
|
|
}
|
|
case userResult, ok := <-userResultCh:
|
|
if !ok {
|
|
userResultCh = nil
|
|
continue
|
|
}
|
|
if err := updateUser(ctx, conn, userResult); err != nil {
|
|
log.Printf("Failed to update user %d: %v", userResult.UserID, err)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
|
|
ctx := context.Background()
|
|
|
|
conn, err := pgx.Connect(ctx, *databaseURL)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer conn.Close(ctx)
|
|
|
|
ckey, csecret := os.Getenv("TWITTER_OAUTH_CONSUMER_KEY"), os.Getenv("TWITTER_OAUTH_CONSUMER_SECRET")
|
|
if ckey == "" || csecret == "" {
|
|
fmt.Fprintf(os.Stderr, "No TWITTER_OAUTH_CONSUMER_KEY or TWITTER_OAUTH_CONSUMER_SECRET\n")
|
|
os.Exit(1)
|
|
}
|
|
twitterOAuthConfig := oauth1.NewConfig(ckey, csecret)
|
|
|
|
stop := func() <-chan struct{} {
|
|
st := make(chan os.Signal)
|
|
signal.Notify(st, os.Interrupt)
|
|
stopCh := make(chan struct{})
|
|
go func() {
|
|
<-st
|
|
log.Printf("Shutting down")
|
|
close(stopCh)
|
|
signal.Reset(os.Interrupt)
|
|
}()
|
|
return stopCh
|
|
}()
|
|
|
|
log.Printf("Started up.")
|
|
log.Printf("Doing initial fetch.")
|
|
if err := tick(ctx, conn, twitterOAuthConfig); err != nil {
|
|
log.Printf("tick failed: %v", err)
|
|
}
|
|
t := time.NewTicker(*tickInterval)
|
|
loop:
|
|
for {
|
|
select {
|
|
case <-t.C:
|
|
log.Printf("Tick...")
|
|
if err := tick(ctx, conn, twitterOAuthConfig); err != nil {
|
|
log.Printf("tick failed: %v", err)
|
|
}
|
|
case <-stop:
|
|
break loop
|
|
}
|
|
}
|
|
}
|