4 changes: 1 addition & 3 deletions .github/workflows/hma-release.yaml
@@ -4,10 +4,9 @@ name: Publish HMA docker image

on:
push:
branches:
- main
paths:
- "hasher-matcher-actioner/version.txt"
- ".github/workflows/hma-release.yaml"

env:
REGISTRY: ghcr.io
@@ -57,5 +56,4 @@ jobs:
context: hasher-matcher-actioner
platforms: linux/amd64,linux/arm64
tags: |
${{ env.REGISTRY }}/${{ env.IMAGE_NAME_LC }}/hma:latest
${{ env.REGISTRY }}/${{ env.IMAGE_NAME_LC }}/hma:${{ env.VERSION }}
11 changes: 9 additions & 2 deletions hasher-matcher-actioner/pyproject.toml
@@ -37,12 +37,19 @@ all = [
"types-psycopg2",
"types-python-dateutil",
"gunicorn",
"flask_apscheduler"
"flask_apscheduler",
"psutil>=5.9.0",
"pympler>=0.9"
]

test = [ "pytest" ]

prod = [ "gunicorn" ]
prod = [
"gunicorn",
"psutil>=5.9.0",
"pympler>=0.9"
]


[tool.mypy]
warn_unused_configs = true
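The two new runtime dependencies underpin the memory diagnostics added later in this PR: psutil reports process-level memory statistics (RSS, virtual size), while Pympler can measure the size of individual Python object graphs. A minimal sketch of the kind of probe they enable (the helper name and log format are illustrative, not part of the PR):

```python
# Sketch only: shows what psutil and Pympler are used for in this PR.
import logging

import psutil
from pympler import asizeof


def log_process_memory(logger: logging.Logger, obj: object | None = None) -> None:
    # Process-level numbers from the OS: resident set size and virtual size.
    mem = psutil.Process().memory_info()
    logger.info("rss=%.1f MiB vms=%.1f MiB", mem.rss / 2**20, mem.vms / 2**20)
    if obj is not None:
        # Deep (recursive) size of a single Python object graph, via Pympler.
        logger.info("object size ~%.1f MiB", asizeof.asizeof(obj) / 2**20)
```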
Changes to an additional file (path not shown in this view)
@@ -73,7 +73,7 @@
# Note: When ALLOWED_HOSTNAMES is not set or empty, all hostnames are allowed

# We always want some logging by default; otherwise it's hard to tell what's happening inside the container:
FLASK_LOGGING_CONFIG = dictConfig(
dictConfig(
{
"version": 1,
"formatters": {
@@ -92,6 +92,7 @@
}
)


# If you need to add something to the Flask app, the following hook function can
# be used. Note that adding functionality may or may not prevent the UI from
# working, especially if authentication requirements are added to the API.
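Context for this change: logging.config.dictConfig applies the logging configuration as a side effect and returns None, so binding its result to FLASK_LOGGING_CONFIG stored nothing useful; dropping the assignment keeps the behaviour identical. A generic illustration of the call pattern (the formatter and handler shown are examples, not this file's exact config):

```python
# Illustrative only: dictConfig configures the logging system in place and
# returns None, so there is no value worth assigning to a module constant.
from logging.config import dictConfig

dictConfig(
    {
        "version": 1,
        "formatters": {
            "default": {
                "format": "[%(asctime)s] %(levelname)s in %(module)s: %(message)s",
            }
        },
        "handlers": {
            "wsgi": {"class": "logging.StreamHandler", "formatter": "default"},
        },
        "root": {"level": "INFO", "handlers": ["wsgi"]},
    }
)
```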
Changes to an additional file (path not shown in this view)
@@ -60,7 +60,7 @@
],
)

FLASK_LOGGING_CONFIG = dictConfig(
dictConfig(
{
"version": 1,
"formatters": {
Changes to an additional file (path not shown in this view)
@@ -17,7 +17,9 @@
DBUSER = os.environ.get("POSTGRES_USER", "media_match")
DBPASS = os.environ.get("POSTGRES_PASSWORD", "hunter2")
DBHOST = os.environ.get("POSTGRES_HOST", os.environ.get("POSTGRESS_HOST", "db"))
DBNAME = os.environ.get("POSTGRES_DBNAME", os.environ.get("POSTGRESS_DBNAME", "media_match"))
DBNAME = os.environ.get(
"POSTGRES_DBNAME", os.environ.get("POSTGRESS_DBNAME", "media_match")
)
DATABASE_URI = f"postgresql+psycopg2://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"

# Role configuration
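The edit above is purely a line-wrapping change; the lookup semantics stay the same. These configs resolve the database name with a nested fallback: the correctly spelled POSTGRES_DBNAME wins, the legacy misspelled POSTGRESS_DBNAME is honoured next, and "media_match" is the final default. A compact restatement of that precedence (the helper function is illustrative, not part of the PR):

```python
import os


def env_with_legacy_fallback(name: str, legacy_name: str, default: str) -> str:
    # Prefer the correctly spelled variable, then the legacy misspelling, then the default.
    return os.environ.get(name, os.environ.get(legacy_name, default))


# Equivalent to the nested os.environ.get calls in the config above.
DBNAME = env_with_legacy_fallback("POSTGRES_DBNAME", "POSTGRESS_DBNAME", "media_match")
```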
Changes to an additional file (path not shown in this view)
@@ -14,7 +14,9 @@
DBUSER = os.environ.get("POSTGRES_USER", "media_match")
DBPASS = os.environ.get("POSTGRES_PASSWORD", "hunter2")
DBHOST = os.environ.get("POSTGRES_HOST", os.environ.get("POSTGRESS_HOST", "db"))
DBNAME = os.environ.get("POSTGRES_DBNAME", os.environ.get("POSTGRESS_DBNAME", "media_match"))
DBNAME = os.environ.get(
"POSTGRES_DBNAME", os.environ.get("POSTGRESS_DBNAME", "media_match")
)
DATABASE_URI = f"postgresql+psycopg2://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"

# Role configuration
Changes to an additional file (path not shown in this view)
@@ -21,7 +21,9 @@
DBUSER = os.environ.get("POSTGRES_USER", "media_match")
DBPASS = os.environ.get("POSTGRES_PASSWORD", "hunter2")
DBHOST = os.environ.get("POSTGRES_HOST", os.environ.get("POSTGRESS_HOST", "db"))
DBNAME = os.environ.get("POSTGRES_DBNAME", os.environ.get("POSTGRESS_DBNAME", "media_match"))
DBNAME = os.environ.get(
"POSTGRES_DBNAME", os.environ.get("POSTGRESS_DBNAME", "media_match")
)
DATABASE_URI = f"postgresql+psycopg2://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"

# Role configuration
Changes to an additional file (path not shown in this view)
@@ -16,7 +16,9 @@
DBUSER = os.environ.get("POSTGRES_USER", "media_match")
DBPASS = os.environ.get("POSTGRES_PASSWORD", "hunter2")
DBHOST = os.environ.get("POSTGRES_HOST", os.environ.get("POSTGRESS_HOST", "db"))
DBNAME = os.environ.get("POSTGRES_DBNAME", os.environ.get("POSTGRESS_DBNAME", "media_match"))
DBNAME = os.environ.get(
"POSTGRES_DBNAME", os.environ.get("POSTGRESS_DBNAME", "media_match")
)
DATABASE_URI = f"postgresql+psycopg2://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"

# Role configuration
53 changes: 53 additions & 0 deletions hasher-matcher-actioner/src/OpenMediaMatch/app.py
@@ -33,6 +33,7 @@
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.blueprints import development, hashing, matching, curation, ui
from OpenMediaMatch.utils import dev_utils
from OpenMediaMatch.utils.memory_utils import log_memory_info


def _is_debug_mode():
@@ -58,6 +59,42 @@ def _setup_task_logging(app_logger: logging.Logger):
build_index.logger = app_logger.getChild("Indexer")


def _setup_memory_monitoring(app: flask.Flask):
"""Setup memory monitoring middleware and periodic logging"""
import time

# Track when we last logged memory to avoid spamming logs
last_memory_log = {"time": 0.0}
memory_log_interval = app.config.get(
"MEMORY_LOG_INTERVAL_SECONDS", 300
) # Default 5 min

@app.before_request
def log_memory_before_request():
"""Log memory info before processing requests (rate-limited)"""
current_time = time.time()
if current_time - last_memory_log["time"] >= memory_log_interval:
log_memory_info("Request Start", app.logger)
last_memory_log["time"] = current_time
flask.g.memory_logged = True
else:
flask.g.memory_logged = False

@app.after_request
def log_memory_after_request(response):
"""Log memory info after processing requests (only if we logged before)"""
if getattr(flask.g, "memory_logged", False):
log_memory_info("Request End", app.logger)
return response

app.logger.info(
"Memory monitoring enabled (interval: %d seconds)", memory_log_interval
)

# Log initial memory state
log_memory_info("App Startup", app.logger)


def create_app() -> flask.Flask:
"""
Create and configure the Flask app
@@ -107,6 +144,10 @@ def create_app() -> flask.Flask:

_setup_task_logging(app.logger)

# Add memory monitoring middleware if enabled
if app.config.get("ENABLE_MEMORY_MONITORING", False):
_setup_memory_monitoring(app)

scheduler: APScheduler | None = None

with app.app_context():
Expand Down Expand Up @@ -145,6 +186,18 @@ def create_app() -> flask.Flask:
start_date=now + datetime.timedelta(seconds=15),
)
app.logger.info("Started Apscheduler, initial tasks: %s", tasks)

# Add periodic memory monitoring task if enabled
if app.config.get("ENABLE_MEMORY_MONITORING", False):
scheduler.add_job(
"MemoryMonitor",
lambda: log_memory_info("Periodic Check", app.logger),
trigger="interval",
seconds=int(app.config.get("MEMORY_LOG_INTERVAL_SECONDS", 300)),
start_date=now + datetime.timedelta(seconds=60),
)
app.logger.info("Added periodic memory monitoring task")

scheduler.start()

storage.init_flask(app)
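app.py imports log_memory_info from OpenMediaMatch.utils.memory_utils, but that module is not part of this diff view. A plausible sketch of the helper, assuming it logs process RSS via psutil under the given label; the PR's actual implementation may differ:

```python
# Hypothetical sketch of the imported helper; the real memory_utils.py is not shown here.
import logging

import psutil


def log_memory_info(label: str, logger: logging.Logger) -> None:
    """Log current process memory usage under a short label."""
    mem = psutil.Process().memory_info()
    logger.info(
        "[%s] memory: rss=%.1f MiB vms=%.1f MiB",
        label,
        mem.rss / 2**20,
        mem.vms / 2**20,
    )
```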
Changes to an additional file (path not shown in this view)
@@ -2,7 +2,9 @@

import logging
import time
import gc
import typing as t
from typing import Optional

from threatexchange.signal_type.signal_base import SignalType

@@ -16,6 +18,7 @@
)
from OpenMediaMatch.utils.time_utils import duration_to_human_str
from OpenMediaMatch.utils.memory_utils import trim_process_memory
from OpenMediaMatch.utils.memory_monitoring import MemoryMonitor

logger = logging.getLogger(__name__)

@@ -39,9 +42,24 @@ def build_all_indices(
start = time.time()
logger.info("Running the %s background task", build_all_indices.__name__)
enabled = signal_type_cfgs.get_enabled_signal_types()

# Monitor memory before starting all indices
monitor = MemoryMonitor(enable_detailed_profiling=False)
logger.info(monitor.log_snapshot("Before building all indices"))

for st in enabled.values():
# Force rebuild by clearing the checkpoint
# index_store.store_signal_type_index(
# st,
# st.get_index_cls().build([]),
# SignalTypeIndexBuildCheckpoint.get_empty()
# )
build_index(st, bank_store, index_store)

# Monitor memory after building all indices
logger.info(monitor.log_snapshot("After building all indices"))
logger.info(monitor.log_memory_trend())

logger.info(
"Completed %s background task - %s",
build_all_indices.__name__,
@@ -64,28 +82,56 @@ def build_index(
# First check to see if new signals have appeared since the last build
idx_checkpoint = index_store.get_last_index_build_checkpoint(for_signal_type)
bank_checkpoint = bank_store.get_current_index_build_target(for_signal_type)

if idx_checkpoint == bank_checkpoint:
logger.info("%s index up to date, no build needed", for_signal_type.get_name())
return

logger.info(
"Building index for %s (%d signals)",
for_signal_type.get_name(),
0 if bank_checkpoint is None else bank_checkpoint.total_hash_count,
)

# Use try/finally to ensure aggressive memory trim after build
# Use try/finally to ensure cleanup happens even on exceptions
signal_count = 0
built_index: t.Any | None = None # keep in locals per review nit
built_index: t.Any | None = None

try:
built_index, checkpoint, signal_count = _prepare_index(
# Prepare index with memory monitoring
built_index, checkpoint, signal_count, monitor = _prepare_index(
for_signal_type, bank_store
)

# Monitor memory during index storage
logger.info(
monitor.log_snapshot(
f"Before index storage for {for_signal_type.get_name()}"
)
)
index_store.store_signal_type_index(for_signal_type, built_index, checkpoint)
logger.info(
monitor.log_snapshot(
f"After index storage for {for_signal_type.get_name()}"
)
)

finally:
# Guaranteed cleanup even if exceptions occur
logger.info(
monitor.log_snapshot(f"Before cleanup for {for_signal_type.get_name()}")
)

# Force garbage collection to reclaim memory and attempt to free pages
trim_process_memory(logger, "Indexer")

logger.info(
monitor.log_snapshot(f"After cleanup for {for_signal_type.get_name()}")
)

# Log final memory trends
logger.info(monitor.log_memory_trend())

logger.info(
"Indexed %d signals for %s - %s",
signal_count,
@@ -97,23 +143,55 @@ def _prepare_index(
def _prepare_index(
for_signal_type: t.Type[SignalType],
bank_store: IBankStore,
) -> tuple[t.Any, SignalTypeIndexBuildCheckpoint, int]:
) -> tuple[t.Any, SignalTypeIndexBuildCheckpoint, int, MemoryMonitor]:
"""
Collect signals for the given type, build the index, and compute checkpoint.
Returns a tuple of (built_index, checkpoint, signal_count).
Returns a tuple of (built_index, checkpoint, signal_count, monitor).
"""
# Memory monitoring is always enabled for diagnostics
monitor = MemoryMonitor(enable_detailed_profiling=True)

signal_list: list[tuple[str, int]] = []
signal_count = 0
last_cs = None

# Monitor memory during signal collection
logger.info(
monitor.log_snapshot(
f"Before signal collection for {for_signal_type.get_name()}"
)
)

# Collect signals
for last_cs in bank_store.bank_yield_content(for_signal_type):
signal_list.append((last_cs.signal_val, last_cs.bank_content_id))
signal_count += 1
if signal_count % 10000 == 0: # Log memory every 10k signals
logger.info(
monitor.log_snapshot(
f"After collecting {signal_count} signals for {for_signal_type.get_name()}"
)
)

logger.info(
monitor.log_snapshot(
f"After signal collection for {for_signal_type.get_name()}"
)
)

# Build index
# Monitor memory during index building
logger.info(
monitor.log_snapshot(
f"Before index construction for {for_signal_type.get_name()}"
)
)
index_cls = for_signal_type.get_index_cls()
built_index = index_cls.build(signal_list)
logger.info(
monitor.log_snapshot(
f"After index construction for {for_signal_type.get_name()}"
)
)

# Create checkpoint
checkpoint = SignalTypeIndexBuildCheckpoint.get_empty()
@@ -124,4 +202,4 @@ def _prepare_index(
total_hash_count=signal_count,
)

return built_index, checkpoint, signal_count
return built_index, checkpoint, signal_count, monitor
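This background task also relies on two utilities that the diff view does not include: MemoryMonitor from OpenMediaMatch.utils.memory_monitoring (whose log_snapshot and log_memory_trend return printable strings) and trim_process_memory from OpenMediaMatch.utils.memory_utils. A rough sketch of interfaces that would satisfy the calls above, assuming snapshots are RSS readings taken with psutil and that trimming is gc.collect plus glibc's malloc_trim on Linux; the PR's real implementations may differ:

```python
# Hypothetical sketches matching the interfaces used above; not the PR's actual code.
import ctypes
import gc
import logging
import sys

import psutil


class MemoryMonitor:
    def __init__(self, enable_detailed_profiling: bool = False) -> None:
        # Detailed profiling might enable tracemalloc or Pympler; omitted in this sketch.
        self.enable_detailed_profiling = enable_detailed_profiling
        self._snapshots: list[tuple[str, float]] = []

    def log_snapshot(self, label: str) -> str:
        # Record and return a one-line RSS reading for the current process.
        rss_mib = psutil.Process().memory_info().rss / 2**20
        self._snapshots.append((label, rss_mib))
        return f"[{label}] rss={rss_mib:.1f} MiB"

    def log_memory_trend(self) -> str:
        # Summarize growth between the first and last recorded snapshots.
        if len(self._snapshots) < 2:
            return "memory trend: not enough snapshots"
        first, last = self._snapshots[0][1], self._snapshots[-1][1]
        return f"memory trend: {first:.1f} MiB -> {last:.1f} MiB ({last - first:+.1f} MiB)"


def trim_process_memory(logger: logging.Logger, caller: str) -> None:
    # Run garbage collection, then ask glibc to return freed pages to the OS (Linux only).
    freed = gc.collect()
    if sys.platform.startswith("linux"):
        try:
            ctypes.CDLL("libc.so.6").malloc_trim(0)
        except (OSError, AttributeError):
            pass
    logger.info("%s: gc collected %d objects and requested malloc_trim", caller, freed)
```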