// SPDX-FileCopyrightText: 2020 Luke Granger-Brown // // SPDX-License-Identifier: Apache-2.0 package main import ( "context" "encoding/json" "flag" "fmt" "io" "io/ioutil" "log" "net/http" "net/url" "os" "os/signal" "path/filepath" "strings" "sync" "time" "github.com/dghubble/oauth1" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" ) var ( databaseURL = flag.String("database_url", "", "Database URL") tickInterval = flag.Duration("tick_interval", 1*time.Minute, "Tick interval") workerCount = flag.Int("workers", 20, "Workers to spawn for fetching info from Twitter API") workAtOnce = flag.Int("work_at_once", 40, "Work to enqueue at once (i.e. we will only look for this much work to do).") mediaFetchDestination = flag.String("media_fetch_destination", "/tmp", "Path to write fetched media to.") mediaTickInterval = flag.Duration("media_tick_interval", 0, "Tick interval for media - takes priority over --tick_interval if non-zero.") mediaWorkAtOnce = flag.Int("media_work_at_once", 0, "Work to enqueue at once for media - takes priority over --work_at_once if non-zero.") relatedTickInterval = flag.Duration("related_tick_interval", 0, "Tick interval for related - takes priority over --tick_interval if non-zero.") relatedWorkAtOnce = flag.Int("related_work_at_once", 0, "Work to enqueue at once for related - takes priority over --work_at_once if non-zero.") ) type WorkerConfig struct { Name string TickInterval time.Duration OAuthConfig *oauth1.Config WorkAtOnce int DatabasePool *pgxpool.Pool WorkQueue chan func(context.Context) } func (wc WorkerConfig) Merge(other WorkerConfig) WorkerConfig { if other.Name != "" { wc.Name = other.Name } if other.TickInterval > 0 { wc.TickInterval = other.TickInterval } if other.OAuthConfig != nil { wc.OAuthConfig = other.OAuthConfig } if other.WorkAtOnce != 0 { wc.WorkAtOnce = other.WorkAtOnce } if other.DatabasePool != nil { wc.DatabasePool = other.DatabasePool } if other.WorkQueue != nil { wc.WorkQueue = other.WorkQueue } return wc } type worker func(context.Context, WorkerConfig) error type tweetVideoVariant struct { Bitrate int `json:"bitrate"` ContentType string `json:"content_type"` URL string `json:"url"` } type tweetVideoInfo struct { Variants []tweetVideoVariant `json:"variants"` } type tweetExtendedMedia struct { ID int64 `json:"id"` Type string `json:"type"` MediaURLHTTPS string `json:"media_url_https"` VideoInfo *tweetVideoInfo `json:"video_info"` } type tweetExtendedEntities struct { Media []tweetExtendedMedia `json:"media"` } type tweetData struct { ID int64 `json:"id"` Text string `json:"full_text"` QuotedStatus *tweetData `json:"quoted_status"` RetweetedStatus *tweetData `json:"retweeted_status"` InReplyToStatusID *int64 `json:"in_reply_to_status_id"` ExtendedEntities tweetExtendedEntities `json:"extended_entities"` Data json.RawMessage `json:"-"` } func mediaFetchOutputPath(u string) (string, error) { p, err := url.Parse(u) if err != nil { return "", err } path := strings.TrimLeft(p.Path, "/") return filepath.Join(*mediaFetchDestination, p.Host, path), nil } type databaseQuerier interface { Query(context.Context, string, ...interface{}) (pgx.Rows, error) } func fetchTweets(ctx context.Context, conn databaseQuerier, query string, args ...interface{}) ([]tweetData, error) { rows, err := conn.Query(ctx, query, args...) if err != nil { return nil, fmt.Errorf("querying: %w", err) } var tweets []tweetData for rows.Next() { var tweetJSON pgtype.JSONB if err := rows.Scan(&tweetJSON); err != nil { return nil, fmt.Errorf("scanning: %w", err) } var tweet tweetData if err := json.Unmarshal(tweetJSON.Bytes, &tweet); err != nil { return nil, fmt.Errorf("unmarshaling: %w", err) } tweets = append(tweets, tweet) } if rows.Err() != nil { return nil, fmt.Errorf("retrieving: %w", err) } return tweets, nil } func mediaFetchTick(ctx context.Context, cfg WorkerConfig) error { tweets, err := fetchTweets(ctx, cfg.DatabasePool, "SELECT object FROM tweets WHERE NOT fetched_media LIMIT $1", cfg.WorkAtOnce) if err != nil { return fmt.Errorf("fetching unfetched media tweet IDs: %w", err) } var ids []int64 for _, t := range tweets { ids = append(ids, t.ID) } log.Printf("[%v] Got %d tweets (%v)", cfg.Name, len(tweets), ids) var wg sync.WaitGroup wg.Add(len(tweets)) for _, t := range tweets { t := t cfg.WorkQueue <- func(ctx context.Context) { defer wg.Done() // Abbreviated message for tweets with no media, which is the majority of them. if len(t.ExtendedEntities.Media) == 0 { if _, err := cfg.DatabasePool.Exec(ctx, "UPDATE tweets SET fetched_media=true WHERE id=$1", t.ID); err != nil { log.Printf("[%v:%d] Failed to update tweets table: %v", cfg.Name, t.ID, err) } log.Printf("[%v:%d] No media", cfg.Name, t.ID) return } log.Printf("[%v:%d] Starting work", cfg.Name, t.ID) defer log.Printf("[%v:%d] Done with work", cfg.Name, t.ID) var mediaURLs []string for _, m := range t.ExtendedEntities.Media { if m.MediaURLHTTPS != "" { mediaURLs = append(mediaURLs, m.MediaURLHTTPS) } if m.VideoInfo != nil { bestBitrate := -1 var bestURL string for _, v := range m.VideoInfo.Variants { if v.ContentType == "application/x-mpegURL" { // m3u8 isn't very interesting since we'd have to parse it to get a playlist. continue } if v.Bitrate > bestBitrate { bestBitrate = v.Bitrate bestURL = v.URL } } if bestURL != "" { mediaURLs = append(mediaURLs, bestURL) } } } log.Printf("[%v:%d] %d media URLs to fetch: %v", cfg.Name, t.ID, len(mediaURLs), mediaURLs) for _, u := range mediaURLs { out, err := mediaFetchOutputPath(u) if err != nil { log.Printf("[%v:%d:%v] fetching output path: %v", cfg.Name, t.ID, u, err) return } // Fetch u -> out if fi, err := os.Stat(out); err == nil { // File already exists, does it have non-zero size? if fi.Size() > 0 { // It does! log.Printf("[%v:%d:%v] %v already exists; skipping", cfg.Name, t.ID, u, out) continue } log.Printf("[%v:%d:%v] %v already exists but has zero size; refetching", cfg.Name, t.ID, u, out) } else if !os.IsNotExist(err) { log.Printf("[%v:%d:%v] checking output path %v: %v", cfg.Name, t.ID, u, out, err) return } if err := os.MkdirAll(filepath.Dir(out), 0755); err != nil { log.Printf("[%v:%d:%v] creating output path %v: %v", cfg.Name, t.ID, u, out, err) return } if err := func() error { resp, err := http.Get(u) if err != nil { return fmt.Errorf("http.Get: %w", err) } defer resp.Body.Close() switch resp.StatusCode { case http.StatusOK: break // This is what we expect. case http.StatusNotFound: log.Printf("[%v:%d:%v] got a 404; marking as successful since this has probably been deleted.", cfg.Name, t.ID, u) return nil default: return fmt.Errorf("HTTP status %d - %v", resp.StatusCode, resp.Status) } f, err := os.Create(out) if err != nil { return err } defer f.Close() written, err := io.Copy(f, resp.Body) if err != nil { // Try to delete the file we created to force a refetch next time. _ = f.Truncate(0) // Ignore the error, though. return fmt.Errorf("IO error: %w", err) } log.Printf("[%v:%d:%v] downloaded %d bytes.", cfg.Name, t.ID, u, written) return nil }(); err != nil { log.Printf("[%v:%d:%v] %v", cfg.Name, t.ID, u, err) return } } log.Printf("[%v:%d] Done downloading media, marking as complete.", cfg.Name, t.ID) if _, err := cfg.DatabasePool.Exec(ctx, "UPDATE tweets SET fetched_media=true WHERE id=$1", t.ID); err != nil { log.Printf("[%v:%d] Failed to update tweets table: %v", cfg.Name, t.ID, err) } } } wg.Wait() log.Printf("[%v] Done with work.", cfg.Name) return nil } func relatedFetchTick(ctx context.Context, cfg WorkerConfig) error { tweets, err := fetchTweets(ctx, cfg.DatabasePool, "SELECT object FROM tweets WHERE NOT fetched_related_tweets LIMIT $1", cfg.WorkAtOnce) if err != nil { return fmt.Errorf("fetching unfetched related tweet IDs: %w", err) } var ids []int64 for _, t := range tweets { ids = append(ids, t.ID) } log.Printf("[%v] Got %d tweets (%v)", cfg.Name, len(tweets), ids) var wg sync.WaitGroup for _, t := range tweets { var tweetIDsToFetch []int64 if t.QuotedStatus != nil { tweetIDsToFetch = append(tweetIDsToFetch, t.QuotedStatus.ID) } if t.RetweetedStatus != nil { tweetIDsToFetch = append(tweetIDsToFetch, t.RetweetedStatus.ID) } if t.InReplyToStatusID != nil { tweetIDsToFetch = append(tweetIDsToFetch, *t.InReplyToStatusID) } log.Printf("[%v:%d] Got %d tweets to fetch (%v)", cfg.Name, t.ID, len(tweetIDsToFetch), tweetIDsToFetch) if len(tweetIDsToFetch) == 0 { if _, err := cfg.DatabasePool.Exec(ctx, "UPDATE tweets SET fetched_related_tweets=true WHERE id=$1", t.ID); err != nil { log.Printf("[%v:%d] Updating tweet to mark as completed: %v", cfg.Name, t.ID, err) } else { log.Printf("[%v:%d] Successfully marked as done (no tweets to fetch).", cfg.Name, t.ID) } continue } wg.Add(1) t := t cfg.WorkQueue <- func(ctx context.Context) { defer wg.Done() // Fetch user credentials for this. // TODO(lukegb): maybe this needs to be smarter (e.g. for protected accounts?), but for now we just pick someone who could see the original tweet. var accessToken, accessSecret string if err := cfg.DatabasePool.QueryRow(ctx, "SELECT ua.access_token, ua.access_secret FROM user_accounts ua INNER JOIN user_accounts_tweets uat ON uat.userid=ua.userid WHERE uat.tweetid=$1 LIMIT 1", t.ID).Scan(&accessToken, &accessSecret); err != nil { log.Printf("[%v:%d] Retrieving OAuth user credentials: %v", cfg.Name, t.ID, err) return } httpClient := cfg.OAuthConfig.Client(ctx, oauth1.NewToken(accessToken, accessSecret)) for _, tid := range tweetIDsToFetch { // Check if we already have tid. log.Printf("[%v:%d] Fetching %d", cfg.Name, t.ID, tid) var cnt int if err := cfg.DatabasePool.QueryRow(ctx, "SELECT COUNT(1) FROM tweets WHERE id=$1", tid).Scan(&cnt); err != nil { log.Printf("[%v:%d:%d] Existence check failed: %v", cfg.Name, t.ID, tid, err) return } if cnt > 0 { log.Printf("[%v:%d] Already have %d", cfg.Name, t.ID, tid) continue } // We don't have it already; let's fetch it from ze Twitterz. req, err := http.NewRequest("GET", fmt.Sprintf("https://api.twitter.com/1.1/statuses/show.json?id=%d&include_entities=true&include_ext_alt_text=true", tid), nil) if err != nil { log.Printf("[%v:%d:%d] Constructing GET request: %v", cfg.Name, t.ID, tid, err) return } resp, err := httpClient.Do(req.WithContext(ctx)) if err != nil { log.Printf("[%v:%d:%d] Executing GET request: %v", cfg.Name, t.ID, tid, err) return } defer resp.Body.Close() switch resp.StatusCode { case http.StatusOK: // Expected. break case http.StatusNotFound: // Tweet has probably been deleted, or is otherwise ACLed. log.Printf("[%v:%d:%d] Not found response; assuming that it's been deleted and treating as a success.", cfg.Name, t.ID, tid) continue default: log.Printf("[%v:%d:%d] GET statuses/show.json returned %d - %v", cfg.Name, t.ID, tid, resp.StatusCode, resp.Status) return } tweetBytes, err := ioutil.ReadAll(resp.Body) if err != nil { log.Printf("[%v:%d:%d] ReadAll: %v", cfg.Name, t.ID, tid, err) return } var subtweet tweetData if err := json.Unmarshal(tweetBytes, &subtweet); err != nil { log.Printf("[%v:%d:%d] Decoding tweet as JSON: %v", cfg.Name, t.ID, tid, err) return } // Insert it into the database. if _, err := cfg.DatabasePool.Exec(ctx, "INSERT INTO tweets (id, text, object) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING", subtweet.ID, subtweet.Text, tweetBytes); err != nil { log.Printf("[%v:%d:%d] Inserting tweet into database: %v", cfg.Name, t.ID, tid, err) return } // Insert the ACLs. if _, err := cfg.DatabasePool.Exec(ctx, "INSERT INTO user_accounts_tweets SELECT userid, $1 tweetid, false on_timeline FROM user_accounts_tweets WHERE tweetid=$2 ON CONFLICT DO NOTHING", tid, t.ID); err != nil { log.Printf("[%v:%d:%d] Inserting ACLs into table failed: %v", cfg.Name, t.ID, tid, err) return } } // We have fetched all the subtweets, mark the tweet as complete. if _, err := cfg.DatabasePool.Exec(ctx, "UPDATE tweets SET fetched_related_tweets=true WHERE id=$1", t.ID); err != nil { log.Printf("[%v:%d] Updating tweet to mark as completed: %v", cfg.Name, t.ID, err) } else { log.Printf("[%v:%d] Successfully processed.", cfg.Name, t.ID) } } } wg.Wait() log.Printf("[%v] Done with work.", cfg.Name) return nil } func tickToWorker(w worker) worker { return func(ctx context.Context, cfg WorkerConfig) error { log.Printf("[%v] Performing initial tick.", cfg.Name) if err := w(ctx, cfg); err != nil { log.Printf("[%v] Initial tick failed: %v", cfg.Name, err) } t := time.NewTicker(cfg.TickInterval) defer t.Stop() for { select { case <-ctx.Done(): return nil case <-t.C: log.Printf("[%v] Tick...", cfg.Name) if err := w(ctx, cfg); err != nil { log.Printf("[%v] Tick failed: %v", cfg.Name, err) } } } } } var ( mediaFetchWorker = tickToWorker(mediaFetchTick) relatedFetchWorker = tickToWorker(relatedFetchTick) ) func main() { flag.Parse() ctx := context.Background() conn, err := pgxpool.Connect(ctx, *databaseURL) if err != nil { fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err) os.Exit(1) } defer conn.Close() ckey, csecret := os.Getenv("TWITTER_OAUTH_CONSUMER_KEY"), os.Getenv("TWITTER_OAUTH_CONSUMER_SECRET") if ckey == "" || csecret == "" { fmt.Fprintf(os.Stderr, "No TWITTER_OAUTH_CONSUMER_KEY or TWITTER_OAUTH_CONSUMER_SECRET\n") os.Exit(1) } twitterOAuthConfig := oauth1.NewConfig(ckey, csecret) stop := func() <-chan struct{} { st := make(chan os.Signal) signal.Notify(st, os.Interrupt) stopCh := make(chan struct{}) go func() { <-st log.Printf("Shutting down") close(stopCh) signal.Reset(os.Interrupt) }() return stopCh }() cctx, cancel := context.WithCancel(ctx) var wg sync.WaitGroup workQueue := make(chan func(context.Context)) wg.Add(*workerCount) for n := 0; n < *workerCount; n++ { go func() { ctx := cctx defer wg.Done() for { select { case <-ctx.Done(): return case f := <-workQueue: f(ctx) } } }() } workers := map[string]worker{ "media": mediaFetchWorker, "related": relatedFetchWorker, } wg.Add(len(workers)) cfg := WorkerConfig{ TickInterval: *tickInterval, OAuthConfig: twitterOAuthConfig, WorkAtOnce: *workAtOnce, DatabasePool: conn, WorkQueue: workQueue, } configs := map[string]WorkerConfig{ "media": cfg.Merge(WorkerConfig{ Name: "media", TickInterval: *mediaTickInterval, WorkAtOnce: *mediaWorkAtOnce, }), "related": cfg.Merge(WorkerConfig{ Name: "related", TickInterval: *relatedTickInterval, WorkAtOnce: *relatedWorkAtOnce, }), } for wn, w := range workers { wn, w := wn, w wcfg, ok := configs[wn] if !ok { wcfg = cfg } go func() { defer wg.Done() log.Printf("Starting worker %v.", wn) if err := w(cctx, wcfg); err != nil { log.Printf("[%v] %v", wn, err) } }() } <-stop log.Printf("Got stop signal. Shutting down...") cancel() wg.Wait() }