355 lines
13 KiB
Python
355 lines
13 KiB
Python
|
|
import os
|
||
|
|
import sqlite3
|
||
|
|
import zipfile
|
||
|
|
import io
|
||
|
|
import logging
|
||
|
|
from functools import wraps
|
||
|
|
from datetime import datetime
|
||
|
|
from flask import (
|
||
|
|
Flask, g, render_template, request, redirect, url_for, abort,
|
||
|
|
current_app, jsonify, send_file
|
||
|
|
)
|
||
|
|
import markdown2
|
||
|
|
from hmac import compare_digest as safe_str_cmp
|
||
|
|
from base64 import b64decode
|
||
|
|
|
||
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("textbooru")

# Directory holding this module, made absolute so DB_PATH does not depend on
# the working directory in effect when it is later used.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(BASE_DIR, "data.db")

# Basic-auth credentials for the admin routes, read from the ADMIN_USER and
# ADMIN_PASS environment variables. This module does not load a .env file
# itself; values from one must already be exported into the environment.
ADMIN_USER = os.environ.get("ADMIN_USER", "admin")
ADMIN_PASS = os.environ.get("ADMIN_PASS", "password")  # change this in production

app = Flask(__name__)
# Flask's signing key; the fallback is a placeholder, not a secret.
app.config['SECRET_KEY'] = os.environ.get("SECRET_KEY", "change-me")

# The fallbacks above are well-known strings: make it visible in the logs when
# they are actually in effect instead of silently running with them.
if "ADMIN_PASS" not in os.environ:
    logger.warning("ADMIN_PASS is not set; admin routes use the default password")
if "SECRET_KEY" not in os.environ:
    logger.warning("SECRET_KEY is not set; using the insecure default secret key")
|
||
|
|
|
||
|
|
# --- DB utils ----------------------------------------------------------------
def get_db():
    """Return this app context's SQLite connection, opening it on first use.

    The connection is cached on ``flask.g`` as ``_database`` and uses
    ``sqlite3.Row`` so columns can be read by name. The schema is created
    whenever the ``posts`` table is missing -- not only when the database file
    is absent -- so a pre-existing empty ``data.db`` gets initialised instead
    of failing every query with "no such table".
    """
    db = getattr(g, "_database", None)
    if db is None:
        db = g._database = sqlite3.connect(DB_PATH, check_same_thread=False)
        db.row_factory = sqlite3.Row
        has_schema = db.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'posts'"
        ).fetchone()
        if has_schema is None:
            init_db(db)
    return db
|
||
|
|
|
||
|
|
def init_db(db):
    """Create the schema on *db* and commit.

    Every statement uses IF NOT EXISTS, so running it again is harmless. The
    core tables are created in one script. The FTS5 index is created
    separately, so a SQLite build without the fts5 module still gets
    tags/post_tags/votes/flags; add_post and search already cope with a
    missing ``posts_fts``. ``post_tags`` carries UNIQUE(post_id, tag_id) so
    the ``INSERT OR IGNORE`` used when tagging actually de-duplicates.
    """
    cur = db.cursor()
    cur.executescript("""
        CREATE TABLE IF NOT EXISTS posts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT,
            body TEXT,
            created_at TEXT
        );
        CREATE TABLE IF NOT EXISTS tags (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT UNIQUE);
        CREATE TABLE IF NOT EXISTS post_tags (post_id INTEGER, tag_id INTEGER, UNIQUE (post_id, tag_id));
        CREATE TABLE IF NOT EXISTS votes (id INTEGER PRIMARY KEY AUTOINCREMENT, post_id INTEGER, voter_ip TEXT, vote INTEGER, created_at TEXT);
        CREATE TABLE IF NOT EXISTS flags (id INTEGER PRIMARY KEY AUTOINCREMENT, post_id INTEGER, reason TEXT, created_at TEXT, resolved INTEGER DEFAULT 0);
    """)
    try:
        cur.execute(
            "CREATE VIRTUAL TABLE IF NOT EXISTS posts_fts "
            "USING fts5(title, body, content='posts', content_rowid='id')"
        )
    except sqlite3.OperationalError:
        # e.g. "no such module: fts5"; the rest of the app degrades to LIKE search.
        logger.warning("FTS5 unavailable; full-text search will fall back to LIKE")
    db.commit()
    logger.info("Initialized DB at %s", DB_PATH)
|
||
|
|
|
||
|
|
@app.teardown_appcontext
def close_conn(exc):
    """Close the SQLite connection get_db() cached on ``g``, if one was opened.

    *exc* is the exception that ended the app context (or None); it is unused.
    """
    conn = g.get("_database")
    if conn is None:
        return
    conn.close()
|
||
|
|
|
||
|
|
# --- Helpers ----------------------------------------------------------------
|
||
|
|
def require_basic_auth(fn):
|
||
|
|
@wraps(fn)
|
||
|
|
def wrapper(*args, **kwargs):
|
||
|
|
auth = request.headers.get("Authorization")
|
||
|
|
if not auth or not auth.lower().startswith("basic "):
|
||
|
|
return _basic_auth_required()
|
||
|
|
try:
|
||
|
|
payload = b64decode(auth.split(" ", 1)[1]).decode("utf-8")
|
||
|
|
user, passwd = payload.split(":", 1)
|
||
|
|
except Exception:
|
||
|
|
return _basic_auth_required()
|
||
|
|
if not (safe_str_cmp(user, ADMIN_USER) and safe_str_cmp(passwd, ADMIN_PASS)):
|
||
|
|
return _basic_auth_required()
|
||
|
|
return fn(*args, **kwargs)
|
||
|
|
return wrapper
|
||
|
|
|
||
|
|
def _basic_auth_required():
    """Build the 401 response that makes browsers prompt for Basic credentials."""
    response = current_app.make_response("Authentication required")
    response.status_code = 401
    response.headers['WWW-Authenticate'] = 'Basic realm="TextBooru admin"'
    return response
|
||
|
|
|
||
|
|
def add_post(title, body, tags):
|
||
|
|
db = get_db()
|
||
|
|
cur = db.cursor()
|
||
|
|
created = datetime.utcnow().isoformat()
|
||
|
|
cur.execute("INSERT INTO posts (title, body, created_at) VALUES (?, ?, ?)", (title, body, created))
|
||
|
|
post_id = cur.lastrowid
|
||
|
|
try:
|
||
|
|
cur.execute("INSERT INTO posts_fts (rowid, title, body) VALUES (?, ?, ?)", (post_id, title, body))
|
||
|
|
except Exception:
|
||
|
|
logger.exception("FTS insert failed")
|
||
|
|
for t in tags:
|
||
|
|
name = t.strip().lower()
|
||
|
|
if not name:
|
||
|
|
continue
|
||
|
|
cur.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (name,))
|
||
|
|
cur.execute("SELECT id FROM tags WHERE name = ?", (name,))
|
||
|
|
row = cur.fetchone()
|
||
|
|
if row:
|
||
|
|
tag_id = row["id"]
|
||
|
|
cur.execute("INSERT OR IGNORE INTO post_tags (post_id, tag_id) VALUES (?, ?)", (post_id, tag_id))
|
||
|
|
db.commit()
|
||
|
|
return post_id
|
||
|
|
|
||
|
|
def get_tags_for_post(post_id):
    """Return the names of the tags linked to *post_id*, alphabetically (empty list if none)."""
    cursor = get_db().cursor()
    query = """
        SELECT t.name FROM tags t
        JOIN post_tags pt ON pt.tag_id = t.id
        WHERE pt.post_id = ?
        ORDER BY t.name
    """
    return [record["name"] for record in cursor.execute(query, (post_id,)).fetchall()]
|
||
|
|
|
||
|
|
def get_vote_counts(post_id):
    """Tally the votes cast on *post_id*.

    Returns a dict with ``score`` (sum of vote values), ``ups`` (number of
    positive votes) and ``downs`` (number of negative votes); each is 0 when
    the post has no votes.
    """
    cursor = get_db().cursor()
    cursor.execute("SELECT SUM(vote) as score, SUM(CASE WHEN vote>0 THEN 1 ELSE 0 END) as ups, SUM(CASE WHEN vote<0 THEN 1 ELSE 0 END) as downs FROM votes WHERE post_id = ?", (post_id,))
    tally = cursor.fetchone()
    if tally is None:
        return {"score": 0, "ups": 0, "downs": 0}
    # SUM() over zero rows is NULL, hence the `or 0`.
    score, ups, downs = (tally[column] or 0 for column in ("score", "ups", "downs"))
    return {"score": score, "ups": ups, "downs": abs(downs)}
|
||
|
|
|
||
|
|
# --- Routes -----------------------------------------------------------------
@app.context_processor
def inject_tags():
    """Expose ``all_tags`` to templates: up to 50 (name, cnt) rows, most-used first.

    Ties are broken by name so the list is stable between requests. A database
    error is logged and yields an empty list, so a tag-cloud failure never
    breaks page rendering.
    """
    cur = get_db().cursor()
    try:
        cur.execute(
            "SELECT t.name, count(*) as cnt FROM tags t "
            "JOIN post_tags pt ON pt.tag_id = t.id "
            "GROUP BY t.id ORDER BY cnt DESC, t.name LIMIT 50"
        )
        tags = cur.fetchall()
    except sqlite3.Error:
        logger.exception("Could not load tag list for templates")
        tags = []
    return dict(all_tags=tags)
|
||
|
|
|
||
|
|
@app.route("/")
def index():
    """Render the newest posts, 20 per page, plus the total page count.

    ``?page=`` is parsed with Werkzeug's ``type=int`` conversion, which falls
    back to the default (1) for non-numeric input instead of raising
    ValueError and producing a 500. Pages below 1 are clamped to 1.
    """
    page = max(1, request.args.get("page", 1, type=int))
    per_page = 20
    offset = (page - 1) * per_page
    db = get_db()
    cur = db.cursor()
    cur.execute("SELECT count(*) as c FROM posts")
    total = cur.fetchone()["c"] or 0
    cur.execute("SELECT id, title, body, created_at FROM posts ORDER BY id DESC LIMIT ? OFFSET ?", (per_page, offset))
    posts = cur.fetchall()
    pages = (total + per_page - 1) // per_page  # ceiling division
    return render_template("index.html", posts=posts, get_tags=get_tags_for_post, markdown=markdown2.markdown, page=page, pages=pages)
|
||
|
|
|
||
|
|
@app.route("/post/<int:pid>")
def view_post(pid):
    """Show one post with its tags, vote tallies and count of unresolved flags; 404 if absent."""
    cur = get_db().cursor()
    post = cur.execute("SELECT * FROM posts WHERE id = ?", (pid,)).fetchone()
    if post is None:
        abort(404)
    tags = get_tags_for_post(pid)
    votes = get_vote_counts(pid)
    open_flags = cur.execute(
        "SELECT count(*) as c FROM flags WHERE post_id = ? AND resolved = 0", (pid,)
    ).fetchone()["c"] or 0
    return render_template(
        "post.html",
        post=post,
        tags=tags,
        votes=votes,
        flags=open_flags,
        markdown=markdown2.markdown,
    )
|
||
|
|
|
||
|
|
@app.route("/add", methods=["GET", "POST"])
def add_route():
    """GET shows the new-post form; POST creates a post and redirects to it.

    Form fields: ``title`` (optional), ``body`` (required, 400 when blank after
    stripping) and ``tags`` (comma-separated; add_post strips and lower-cases
    each entry).
    """
    if request.method != "POST":
        return render_template("add.html")
    form = request.form
    title = form.get("title", "").strip()
    body = form.get("body", "").strip()
    raw_tags = form.get("tags")
    tags = raw_tags.split(",") if raw_tags else []
    if not body:
        return "Body required", 400
    new_id = add_post(title, body, tags)
    return redirect(url_for("view_post", pid=new_id))
|
||
|
|
|
||
|
|
@app.route("/tag/<name>")
def tag_page(name):
    """List the 100 newest posts carrying tag *name*, rendered with the index template.

    *name* is lower-cased for the lookup, matching the lower-case names add_post stores.
    """
    lookup = name.lower()
    cur = get_db().cursor()
    cur.execute("""
        SELECT p.id,p.title,p.body,p.created_at FROM posts p
        JOIN post_tags pt ON pt.post_id = p.id
        JOIN tags t ON t.id = pt.tag_id
        WHERE t.name = ?
        ORDER BY p.id DESC
        LIMIT 100
    """, (lookup,))
    return render_template(
        "index.html",
        posts=cur.fetchall(),
        get_tags=get_tags_for_post,
        current_tag=name,
        markdown=markdown2.markdown,
    )
|
||
|
|
|
||
|
|
def _like_contains_pattern(text):
    """Return a LIKE pattern that matches *text* literally anywhere in a column.

    Backslash, ``%`` and ``_`` are escaped with a backslash, so the query using
    the pattern must declare a backslash as its ESCAPE character.
    """
    escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    return f"%{escaped}%"


@app.route("/search")
def search():
    """Search posts for ``?q=`` and render up to 200 hits, newest first.

    FTS5 ``MATCH`` is tried first. If it raises (FTS5 unavailable, or *q* is
    not valid FTS5 query syntax such as an unbalanced quote), the search falls
    back to a substring LIKE on title and body. ``%``/``_`` in *q* are escaped
    there, so they match literally instead of acting as wildcards.
    """
    q = request.args.get("q", "").strip()
    results = []
    if q:
        db = get_db()
        cur = db.cursor()
        try:
            cur.execute("SELECT rowid FROM posts_fts WHERE posts_fts MATCH ? LIMIT 200", (q,))
            ids = [r["rowid"] for r in cur.fetchall()]
        except sqlite3.Error as exc:
            logger.warning("FTS failed, falling back to LIKE: %s", exc)
            pattern = _like_contains_pattern(q)
            cur.execute(
                "SELECT id FROM posts WHERE title LIKE ? ESCAPE '\\' OR body LIKE ? ESCAPE '\\' LIMIT 200",
                (pattern, pattern),
            )
            ids = [r["id"] for r in cur.fetchall()]
        if ids:
            placeholders = ",".join("?" * len(ids))
            cur.execute(f"SELECT id,title,body,created_at FROM posts WHERE id IN ({placeholders}) ORDER BY id DESC", ids)
            results = cur.fetchall()
    return render_template("search.html", q=q, posts=results, get_tags=get_tags_for_post, markdown=markdown2.markdown)
|
||
|
|
|
||
|
|
# --- Voting endpoints -------------------------------------------------------
@app.route("/vote/<int:pid>", methods=["POST"])
def vote(pid):
    """Record, flip or withdraw the caller's vote on post *pid*, then redirect to it.

    Form field ``vote`` must be ``"up"`` or ``"down"``. Any other value,
    including a missing field, is rejected with 400 rather than silently
    counted as a downvote. Voting on a post that does not exist returns 404,
    so no orphan vote rows are written. Voters are identified only by
    ``request.remote_addr`` (prototype-grade; behind a reverse proxy this is
    presumably the proxy's address -- confirm deployment). Repeating the same
    vote removes it, and the opposite vote replaces it.
    """
    vote_value = {"up": 1, "down": -1}.get(request.form.get("vote"))
    if vote_value is None:
        abort(400)
    voter_ip = request.remote_addr or "anon"
    db = get_db()
    cur = db.cursor()
    cur.execute("SELECT 1 FROM posts WHERE id = ?", (pid,))
    if cur.fetchone() is None:
        abort(404)
    now = datetime.utcnow().isoformat()
    cur.execute("SELECT id, vote FROM votes WHERE post_id = ? AND voter_ip = ?", (pid, voter_ip))
    row = cur.fetchone()
    if row is None:
        cur.execute("INSERT INTO votes (post_id, voter_ip, vote, created_at) VALUES (?, ?, ?, ?)", (pid, voter_ip, vote_value, now))
    elif row["vote"] == vote_value:
        # same vote again: toggle it off
        cur.execute("DELETE FROM votes WHERE id = ?", (row["id"],))
    else:
        cur.execute("UPDATE votes SET vote = ?, created_at = ? WHERE id = ?", (vote_value, now, row["id"]))
    db.commit()
    return redirect(url_for("view_post", pid=pid))
|
||
|
|
|
||
|
|
# --- Tag autocomplete -------------------------------------------------------
@app.route("/_tags")
def tag_suggest():
    """Return tag names as a JSON list for autocomplete.

    Without ``q``, this returns the first 50 tags alphabetically. Otherwise it
    returns up to 25 tags whose names start with ``q`` (lower-cased).
    Backslash, ``%`` and ``_`` in ``q`` are escaped, so a prefix containing
    ``_`` matches a literal underscore instead of any single character.
    """
    q = request.args.get("q", "").strip().lower()
    cur = get_db().cursor()
    if not q:
        cur.execute("SELECT name FROM tags ORDER BY name LIMIT 50")
    else:
        escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        cur.execute(
            "SELECT name FROM tags WHERE name LIKE ? ESCAPE '\\' ORDER BY name LIMIT 25",
            (escaped + "%",),
        )
    rows = cur.fetchall()
    return jsonify([r["name"] for r in rows])
|
||
|
|
|
||
|
|
# --- Flags / moderation ----------------------------------------------------
@app.route("/flag/<int:pid>", methods=["POST"])
def flag_post(pid):
    """Record an unresolved moderation flag on post *pid* and redirect back to it.

    A blank ``reason`` form field is stored as "no reason". Flagging a post
    that does not exist returns 404 instead of inserting an orphan flag into
    the moderation queue.
    """
    reason = request.form.get("reason", "").strip() or "no reason"
    db = get_db()
    cur = db.cursor()
    cur.execute("SELECT 1 FROM posts WHERE id = ?", (pid,))
    if cur.fetchone() is None:
        abort(404)
    cur.execute("INSERT INTO flags (post_id, reason, created_at, resolved) VALUES (?, ?, ?, 0)", (pid, reason, datetime.utcnow().isoformat()))
    db.commit()
    return redirect(url_for("view_post", pid=pid))
|
||
|
|
|
||
|
|
@app.route("/mod")
@require_basic_auth
def mod_index():
    """Admin moderation queue: the 200 newest flags (resolved or not) with post titles."""
    cursor = get_db().cursor()
    cursor.execute("SELECT f.id,f.post_id,f.reason,f.created_at,f.resolved,p.title FROM flags f LEFT JOIN posts p ON p.id = f.post_id ORDER BY f.created_at DESC LIMIT 200")
    return render_template("mod.html", flags=cursor.fetchall())
|
||
|
|
|
||
|
|
@app.route("/mod/resolve/<int:fid>", methods=["POST"])
@require_basic_auth
def mod_resolve(fid):
    """Mark flag *fid* as resolved (admin only) and return to the moderation queue."""
    conn = get_db()
    conn.cursor().execute("UPDATE flags SET resolved = 1 WHERE id = ?", (fid,))
    conn.commit()
    return redirect(url_for("mod_index"))
|
||
|
|
|
||
|
|
# --- Bulk import (ZIP of .txt/.md) -----------------------------------------
def _split_title(content, max_title_len=120):
    """Split *content* into ``(title, body)``.

    If the first non-empty line is at most *max_title_len* characters after
    stripping, it becomes the title. It is then removed from the body together
    with the line breaks that follow it (LF or CRLF). Otherwise the title is
    "" and the body is *content* unchanged.
    """
    for line in content.splitlines():
        stripped = line.strip()
        if not stripped:
            continue
        if len(stripped) <= max_title_len:
            return stripped, content.replace(line, "", 1).lstrip("\r\n")
        break
    return "", content


@app.route("/api/import", methods=["POST"])
@require_basic_auth
def api_import():
    """Import a ZIP of ``.txt``/``.md`` files as posts (admin only).

    Accepts multipart/form-data with file field 'file' containing a ZIP archive.
    Each non-empty .txt/.md member becomes a post. Its first non-empty line
    becomes the title if it is at most 120 characters. Optional 'tags' (comma
    separated) assigns default tags to every imported post. It is read from
    the form field, or from the query string when the form field is absent or
    empty. Directory entries and ``__MACOSX`` metadata are skipped.

    Returns JSON ``{"imported": [{"file": ..., "post_id": ...}], "count": n}``,
    or 400 when no file was sent or it is not a readable ZIP. Members that
    fail to import are logged and skipped.
    """
    if 'file' not in request.files:
        return jsonify({"error":"file required (zip)"}), 400
    upload = request.files['file']
    tags_param = request.form.get("tags") or request.args.get("tags", "")
    default_tags = [t.strip() for t in tags_param.split(",") if t.strip()]
    try:
        archive = zipfile.ZipFile(io.BytesIO(upload.read()))
    except Exception as e:
        return jsonify({"error":"invalid zip", "detail": str(e)}), 400
    imported = []
    with archive:
        for name in archive.namelist():
            if name.endswith("/") or name.startswith("__MACOSX"):
                continue
            if not name.lower().endswith((".txt", ".md")):
                continue
            try:
                content = archive.read(name).decode("utf-8", errors="replace").strip()
                if not content:
                    continue
                title, body = _split_title(content)
                pid = add_post(title, body, default_tags)
                imported.append({"file":name, "post_id": pid})
            except Exception:
                # Best effort: one bad member must not abort the rest of the import.
                logger.exception("Import error for %s", name)
    return jsonify({"imported": imported, "count": len(imported)})
|
||
|
|
|
||
|
|
# --- Export helper: export posts as zip of md (optional) --------------------
@app.route("/api/export/zip")
@require_basic_auth
def api_export_zip():
    """Download every post as ``post_<id>.md`` inside an in-memory ZIP (admin only).

    A post with a title gets it as a leading ``# title`` line followed by a
    blank line; the body follows unchanged.
    """
    cursor = get_db().cursor()
    cursor.execute("SELECT id,title,body,created_at FROM posts ORDER BY id")
    posts = cursor.fetchall()
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w") as archive:
        for post in posts:
            heading = f"# {post['title']}\n\n" if post['title'] else ""
            archive.writestr(f"post_{post['id']}.md", heading + post['body'])
    buffer.seek(0)
    return send_file(buffer, mimetype="application/zip", download_name="export_posts.zip", as_attachment=True)
|
||
|
|
|
||
|
|
# --- Simple health / info --------------------------------------------------
@app.route("/status")
def status():
    """Liveness probe: answer ``{"status": "ok"}`` without touching the database."""
    payload = {"status": "ok"}
    return jsonify(payload)
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Flask development server. Binds all interfaces on port 8080 unless the
    # HOST / PORT environment variables say otherwise.
    app.run(
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8080")),
    )
|