From ecbd80352fc1e1fa7454edac891b78032049edf7 Mon Sep 17 00:00:00 2001 From: Ted Unangst Date: Tue, 26 Nov 2019 00:29:31 -0500 Subject: [PATCH] when sending updates or deletes, send to all fetch recipients too --- activity.go | 3 +++ database.go | 2 ++ web.go | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/activity.go b/activity.go index 684a208..9597525 100644 --- a/activity.go +++ b/activity.go @@ -1209,6 +1209,9 @@ func honkworldwide(user *WhatAbout, honk *Honk) { rcpts[h.XID] = true } } + for _, f := range getbacktracks(honk.XID) { + rcpts[f] = true + } } for a := range rcpts { go deliverate(0, user.ID, a, msg) diff --git a/database.go b/database.go index 1af843c..f0ee4ba 100644 --- a/database.go +++ b/database.go @@ -693,6 +693,7 @@ var stmtGetZonkers, stmtRecentHonkers, stmtGetXonker, stmtSaveXonker, stmtDelete var stmtAllOnts, stmtSaveOnt, stmtUpdateFlags, stmtClearFlags *sql.Stmt var stmtHonksForUserFirstClass, stmtSaveMeta, stmtDeleteMeta, stmtUpdateHonk *sql.Stmt var stmtHonksISaved, stmtGetFilters, stmtSaveFilter, stmtDeleteFilter *sql.Stmt +var stmtGetTracks *sql.Stmt func preparetodie(db *sql.DB, s string) *sql.Stmt { stmt, err := db.Prepare(s) @@ -769,4 +770,5 @@ func prepareStatements(db *sql.DB) { stmtGetFilters = preparetodie(db, "select hfcsid, json from hfcs where userid = ?") stmtSaveFilter = preparetodie(db, "insert into hfcs (userid, json) values (?, ?)") stmtDeleteFilter = preparetodie(db, "delete from hfcs where userid = ? and hfcsid = ?") + stmtGetTracks = preparetodie(db, "select fetches from tracks where xid = ?") } diff --git a/web.go b/web.go index 5f2c52f..c7bbd97 100644 --- a/web.go +++ b/web.go @@ -890,6 +890,33 @@ type Track struct { who string } +func getbacktracks(xid string) []string { + c := make(chan bool) + dumptracks <- c + <-c + row := stmtGetTracks.QueryRow(xid) + var rawtracks string + err := row.Scan(&rawtracks) + if err != nil { + if err != sql.ErrNoRows { + log.Printf("error scanning tracks: %s", err) + } + return nil + } + var rcpts []string + for _, f := range strings.Split(rawtracks, " ") { + idx := strings.LastIndexByte(f, '#') + if idx != -1 { + f = f[:idx] + } + if !strings.HasPrefix(f, "https://") { + f = fmt.Sprintf("%https://%s/inbox", f) + } + rcpts = append(rcpts, f) + } + return rcpts +} + func savetracks(tracks map[string][]string) { db := opendatabase() tx, err := db.Begin() @@ -940,6 +967,7 @@ func savetracks(tracks map[string][]string) { } var trackchan = make(chan Track) +var dumptracks = make(chan chan bool) func tracker() { timeout := 4 * time.Minute @@ -955,6 +983,11 @@ func tracker() { tracks = make(map[string][]string) } sleeper.Reset(timeout) + case c := <-dumptracks: + if len(tracks) > 0 { + savetracks(tracks) + } + c <- true case <-endoftheworld: if len(tracks) > 0 { savetracks(tracks)