""" Database and utility helper functions for Flask app Extracts common patterns to reduce code duplication """ from flask import jsonify from functools import wraps import re import logging logger = logging.getLogger(__name__) # ============================================================================ # Response Helpers # ============================================================================ def success_response(data=None, status=200): """Standard success response""" if data is None: data = {'status': 'ok'} return jsonify(data), status def error_response(error_code, message=None, status=400): """Standard error response""" response = {'error': error_code} if message: response['message'] = message return jsonify(response), status def not_found_response(resource='Resource'): """Standard 404 response""" return error_response('not_found', f'{resource} not found', 404) def validation_error(field, message='Invalid or missing field'): """Standard validation error""" return error_response('validation_error', f'{field}: {message}', 400) # ============================================================================ # Input Sanitization # ============================================================================ def sanitize_text(text, max_length=None, remove_scripts=True): """Sanitize text input by removing scripts and limiting length""" if not text: return '' text = str(text).strip() if remove_scripts: text = re.sub(r']*>.*?', '', text, flags=re.IGNORECASE | re.DOTALL) if max_length: text = text[:max_length] return text def validate_id(id_value, max_length=255): """Validate ID format and length""" if not id_value: return False return len(str(id_value)) <= max_length # ============================================================================ # Database Helpers # ============================================================================ def get_or_404(query, error_msg='Resource not found'): """Get item from query or return 404""" item = query.first() if not 
item: raise NotFoundError(error_msg) return item class NotFoundError(Exception): """Exception for 404 cases""" pass def safe_db_operation(db, operation_func): """ Safely execute database operation with automatic rollback on error Args: db: Database session operation_func: Function that performs DB operations Returns: Tuple of (success: bool, result: any, error: str or None) """ try: result = operation_func() db.commit() return True, result, None except NotFoundError as e: db.rollback() return False, None, str(e) except Exception as e: db.rollback() logger.error(f"Database operation failed: {e}") return False, None, str(e) # ============================================================================ # Model Serialization # ============================================================================ def serialize_profile(profile, include_song_count=False, db=None): """Serialize Profile model to dict""" data = { 'id': profile.id, 'name': profile.name, 'first_name': profile.first_name, 'last_name': profile.last_name, 'default_key': profile.default_key, 'email': profile.email or '', 'contact_number': profile.contact_number or '', 'notes': profile.notes or '' } if include_song_count and db: from postgresql_models import ProfileSong data['song_count'] = db.query(ProfileSong).filter( ProfileSong.profile_id == profile.id ).count() return data def serialize_song(song, include_full_content=False): """Serialize Song model to dict""" data = { 'id': song.id, 'title': song.title, 'artist': song.artist, 'band': song.band, 'singer': song.singer } if include_full_content: data['lyrics'] = song.lyrics or '' data['chords'] = song.chords or '' else: # Preview only data['lyrics'] = (song.lyrics or '')[:200] if song.lyrics else '' data['chords'] = (song.chords or '')[:100] if song.chords else '' return data def serialize_plan(plan): """Serialize Plan model to dict""" return { 'id': plan.id, 'date': plan.date, 'profile_id': plan.profile_id, 'notes': plan.notes or '' } # 
# ============================================================================
# Data Extraction and Validation
# ============================================================================

def extract_profile_data(data):
    """Extract profile fields from a request payload, sanitized and length-capped.

    Args:
        data: Mapping from the request body.

    Returns:
        dict of Profile column values; ``default_key`` falls back to 'C'
        only when the key is absent.
    """
    return {
        'name': sanitize_text(data.get('name'), 255),
        'first_name': sanitize_text(data.get('first_name'), 255),
        'last_name': sanitize_text(data.get('last_name'), 255),
        'default_key': sanitize_text(data.get('default_key', 'C'), 10),
        'email': sanitize_text(data.get('email'), 255),
        'contact_number': sanitize_text(data.get('contact_number'), 50),
        'notes': sanitize_text(data.get('notes'), 5000),
    }


def extract_song_data(data):
    """Extract song fields from a request payload.

    Metadata fields are sanitized and capped at 500 characters; ``lyrics``
    and ``chords`` are passed through unmodified.

    Args:
        data: Mapping from the request body.

    Returns:
        dict of Song column values; ``title`` falls back to 'Untitled' only
        when the key is absent.
    """
    return {
        'title': sanitize_text(data.get('title', 'Untitled'), 500),
        'artist': sanitize_text(data.get('artist'), 500),
        'band': sanitize_text(data.get('band'), 500),
        'singer': sanitize_text(data.get('singer'), 500),
        'lyrics': data.get('lyrics', ''),
        'chords': data.get('chords', ''),
    }


def extract_plan_data(data):
    """Extract plan fields from a request payload.

    ``date`` and ``notes`` are sanitized; ``profile_id`` is passed through
    unvalidated.
    """
    return {
        'date': sanitize_text(data.get('date'), 50),
        'profile_id': data.get('profile_id'),
        'notes': sanitize_text(data.get('notes'), 5000),
    }


# ============================================================================
# Query Helpers
# ============================================================================

def search_songs(db, Song, query_string=''):
    """Return songs whose title, artist, band or singer contains *query_string*.

    Matching is a case-insensitive substring test done in Python on the rows
    returned by ``db.query(Song).all()``. No SQL is built from the query
    text, so it cannot cause SQL injection.

    Args:
        db: Database session exposing ``query(model).all()``.
        Song: Song model class passed to ``db.query``.
        query_string: Search text (first 500 characters are used); a falsy
            value returns every song.

    Returns:
        Matching songs, in query order (the raw ``.all()`` result when
        *query_string* is falsy).
    """
    items = db.query(Song).all()
    if not query_string:
        return items

    # Bound the per-row work. Quotes/semicolons are deliberately kept: the
    # text is only used for an in-memory substring test, and stripping them
    # made queries such as "don't" unable to match titles containing them.
    needle = str(query_string)[:500].lower().strip()

    def matches(song):
        return any(
            needle in (value or '').lower()
            for value in (song.title, song.artist, song.band, song.singer)
        )

    return [song for song in items if matches(song)]


def update_model_fields(model, data, field_mapping=None):
    """Copy values from *data* onto attributes of *model*.

    Args:
        model: SQLAlchemy model instance (any object accepting setattr).
        data: Dict with new values.
        field_mapping: Optional dict mapping data keys to model attribute
            names. When omitted, every key in *data* is set as an attribute
            of the same name. Mapped keys missing from *data* are skipped.
    """
    if field_mapping is None:
        for key, value in data.items():
            setattr(model, key, value)
        return

    for data_key, model_attr in field_mapping.items():
        if data_key in data:
            setattr(model, model_attr, data[data_key])