first part of a do not federate firewall

This commit is contained in:
Ted Unangst 2019-06-24 21:14:47 -04:00
parent 222efa77ea
commit 42f6aab4e5
2 changed files with 39 additions and 0 deletions

22
honk.go
View File

@ -107,6 +107,17 @@ func getInfo(r *http.Request) map[string]interface{} {
return templinfo return templinfo
} }
var donotfedafterdark = make(map[string]bool)
func stealthed(r *http.Request) bool {
addr := r.Header.Get("X-Forwarded-For")
fake := donotfedafterdark[addr]
if fake {
log.Printf("faking 404 for %s", addr)
}
return fake
}
func homepage(w http.ResponseWriter, r *http.Request) { func homepage(w http.ResponseWriter, r *http.Request) {
templinfo := getInfo(r) templinfo := getInfo(r)
u := login.GetUserInfo(r) u := login.GetUserInfo(r)
@ -449,6 +460,11 @@ func outbox(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
if stealthed(r) {
http.NotFound(w, r)
return
}
honks := gethonksbyuser(name, false) honks := gethonksbyuser(name, false)
var jonks []map[string]interface{} var jonks []map[string]interface{}
@ -546,6 +562,11 @@ func showhonk(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
if stealthed(r) {
http.NotFound(w, r)
return
}
xid := fmt.Sprintf("https://%s%s", serverName, r.URL.Path) xid := fmt.Sprintf("https://%s%s", serverName, r.URL.Path)
h := getxonk(user.ID, xid) h := getxonk(user.ID, xid)
if h == nil || !h.Public { if h == nil || !h.Public {
@ -1539,6 +1560,7 @@ func main() {
} }
getconfig("servermsg", &serverMsg) getconfig("servermsg", &serverMsg)
getconfig("servername", &serverName) getconfig("servername", &serverName)
getconfig("dnf", &donotfedafterdark)
prepareStatements(db) prepareStatements(db)
switch cmd { switch cmd {
case "adduser": case "adduser":

17
util.go
View File

@ -260,6 +260,23 @@ func opendatabase() *sql.DB {
} }
func getconfig(key string, value interface{}) error { func getconfig(key string, value interface{}) error {
m, ok := value.(*map[string]bool)
if ok {
rows, err := stmtConfig.Query(key)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var s string
err = rows.Scan(&s)
if err != nil {
return err
}
(*m)[s] = true
}
return nil
}
row := stmtConfig.QueryRow(key) row := stmtConfig.QueryRow(key)
err := row.Scan(value) err := row.Scan(value)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {