diff --git a/slopserver/db.py b/slopserver/db.py index 4a155c0..65a7573 100644 --- a/slopserver/db.py +++ b/slopserver/db.py @@ -1,9 +1,10 @@ from collections.abc import Iterable +from datetime import datetime from urllib.parse import ParseResult from sqlalchemy import select from sqlalchemy.engine import Engine from sqlalchemy.orm import Session -from slopserver.models import Domain, Path, User +from slopserver.models import Domain, Path, User, Report def select_slop(urls: list[ParseResult], engine: Engine) -> Iterable[Domain]: query = select(Domain).where(Domain.domain_name.in_(url[1] for url in urls)) @@ -11,7 +12,7 @@ def select_slop(urls: list[ParseResult], engine: Engine) -> Iterable[Domain]: rows = session.scalars(query).all() return rows -def insert_slop(urls: list[ParseResult], engine: Engine): +def insert_slop(urls: list[ParseResult], engine: Engine, user: User | None = None): domain_dict: dict[str. set[str]] = dict() for url in urls: if not domain_dict.get(url[1]): @@ -35,13 +36,25 @@ def insert_slop(urls: list[ParseResult], engine: Engine): new_domain = Domain(domain_name=domain, paths=list()) new_domain.paths = [Path(path=path) for path in paths] session.add(new_domain) + if user: + for path in new_domain.paths: + new_report = Report(path_id=path.id, user_id=user.id) + session.add(new_report) else: existing_domain = existing_dict[domain] existing_paths = set((path.path for path in existing_domain.paths)) for path in paths: if not path in existing_paths: - existing_domain.paths.append(Path(path=path)) + new_path = Path(path=path) + existing_domain.paths.append(new_path) + session.add(new_path) + session.flush([new_path]) + session.refresh(new_path) + if user: + new_report = Report( + path_id=new_path.id, user_id=user.id, timestamp=datetime.now()) + session.add(new_report) session.commit() diff --git a/slopserver/server.py b/slopserver/server.py index 47d0486..2b6ed40 100644 --- a/slopserver/server.py +++ b/slopserver/server.py @@ -95,15 +95,21 @@ def generate_auth_token(username): encoded_jwt = jwt.encode(bearer_token, TOKEN_SECRET, ALGO) return encoded_jwt +def get_token_user(decoded_token): + user = get_user(decoded_token["sub"], DB_ENGINE) + return user + def verify_auth_token(token: str): try: token = jwt.decode(token, TOKEN_SECRET, ALGO, audience="slopserver") + return token except: raise HTTPException(status_code=401, detail="invalid access token") @app.post("/report") async def report_slop(report: SlopReport, bearer: Annotated[str, AfterValidator(verify_auth_token), Header()]): - insert_slop(report.slop_urls, DB_ENGINE) + user = get_token_user(bearer) + insert_slop(report.slop_urls, DB_ENGINE, user) @app.post("/check") async def check_slop(check: Annotated[SlopReport, Body()], bearer: Annotated[str, AfterValidator(verify_auth_token), Header()]):