Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0d3c622
AI-ASSIST
BillKG-exe Nov 10, 2025
9389ed3
Update ShapeCompletionOverlay.jsx
BillKG-exe Dec 4, 2025
b1be6e0
Update ai-assistant.css
BillKG-exe Dec 4, 2025
af1c532
Update ShapeCompletionOverlay.jsx
BillKG-exe Dec 4, 2025
c6dc692
Update ShapeCompletionOverlay.jsx
BillKG-exe Dec 4, 2025
3e8a6de
Delete frontend/src/components/AI/ShapeCompletionOverlay.jsx
BillKG-exe Dec 4, 2025
8bf8690
Add files via upload
BillKG-exe Dec 4, 2025
97dfc57
Delete frontend/src/components/Canvas.js
BillKG-exe Dec 4, 2025
9dc9c79
Add files via upload
BillKG-exe Dec 4, 2025
ef76a1a
Delete frontend/src/components/AI/AIAssistantPanel.jsx
BillKG-exe Dec 4, 2025
a472ebb
Delete frontend/src/components/AI/AIAP.jsx
BillKG-exe Dec 4, 2025
ce15d0d
Add files via upload
BillKG-exe Dec 4, 2025
d02e70e
Delete backend/routes/ai_assistant.py
BillKG-exe Dec 4, 2025
36a8202
Add files via upload
BillKG-exe Dec 4, 2025
ef08d63
Delete backend/services/llm_service.py
BillKG-exe Dec 4, 2025
c692fed
Add files via upload
BillKG-exe Dec 4, 2025
5d54017
Update Toolbar.js
BillKG-exe Dec 5, 2025
e1680d9
Update useAIAssistant.js
BillKG-exe Dec 5, 2025
94e7b56
Update PromptInput.jsx
BillKG-exe Dec 5, 2025
0a83fc3
Update image_generation_service.py
BillKG-exe Dec 5, 2025
7405303
Testing push
Dec 5, 2025
8ea224f
Added toolbar icon for ai assist
Dec 5, 2025
347d1e1
Working features except beautify
Dec 5, 2025
bad518e
Working features
Dec 5, 2025
a61a0e0
Isolated components for clarity
Dec 6, 2025
7fbb86b
Implement AI drawing pipeline with LLM stroke generation
mobinano Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ node_modules
.env.*

# But include .env.example files
!.env.example
!.env.example*.tgz
7 changes: 3 additions & 4 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from routes.admin import admin_bp
from routes.frontend import frontend_bp
from routes.analytics import analytics_bp
from routes.export import export_bp
from routes.ai_assistant import ai_assistant_bp
from services.db import redis_client
from services.canvas_counter import get_canvas_draw_count
from services.graphql_service import commit_transaction_via_graphql
Expand Down Expand Up @@ -170,33 +170,32 @@ def handle_all_exceptions(e):
app.register_blueprint(undo_redo_bp)
app.register_blueprint(metrics_bp)
app.register_blueprint(auth_bp)
app.register_blueprint(ai_assistant_bp)
app.register_blueprint(rooms_bp)
app.register_blueprint(submit_room_line_bp)
app.register_blueprint(admin_bp)
app.register_blueprint(export_bp)

# Register versioned API v1 blueprints for external applications
from api_v1.auth import auth_v1_bp
from api_v1.canvases import canvases_v1_bp
from api_v1.collaborations import collaborations_v1_bp
from api_v1.notifications import notifications_v1_bp
from api_v1.users import users_v1_bp
from routes.stamps import stamps_bp
from api_v1.templates import templates_v1_bp

app.register_blueprint(auth_v1_bp)
app.register_blueprint(canvases_v1_bp)
app.register_blueprint(collaborations_v1_bp)
app.register_blueprint(notifications_v1_bp)
app.register_blueprint(users_v1_bp)
app.register_blueprint(stamps_bp, url_prefix='/api')
app.register_blueprint(templates_v1_bp)

# Frontend serving must be last to avoid route conflicts
app.register_blueprint(frontend_bp)
app.register_blueprint(analytics_bp)

if __name__ == '__main__':
print(SIGNER_PUBLIC_KEY, SIGNER_PRIVATE_KEY, RECIPIENT_PUBLIC_KEY)
if not redis_client.exists('res-canvas-draw-count'):
init_count = {"id": "res-canvas-draw-count", "value": 0}
logger = __import__('logging').getLogger(__name__)
Expand Down
1 change: 1 addition & 0 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANALYTICS_COLLECTION_NAME = os.getenv("ANALYTICS_COLLECTION_NAME", "analytics_events")
ANALYTICS_AGGREGATES_COLLECTION = os.getenv("ANALYTICS_AGGREGATES_COLLECTION", "analytics_aggregates")
HUGGINGFACE_API_KEY=os.getenv("HUGGINGFACE_API_KEY")

JWT_SECRET = os.getenv("JWT_SECRET", "dev-insecure-change-me")
JWT_ISSUER = "rescanvas"
Expand Down
19 changes: 19 additions & 0 deletions backend/middleware/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,10 @@ def validate_stroke_payload(value) -> Tuple[bool, str]:

if "pathData" not in stroke:
return False, "Stroke must have pathData"

path_data = stroke.get("pathData")
if not isinstance(path_data, (dict, list)):
return False, "Stroke pathData must be an object or array"

is_valid, error = validate_color(stroke.get("color"))
if not is_valid:
Expand All @@ -498,6 +502,21 @@ def validate_stroke_payload(value) -> Tuple[bool, str]:
return False, "Line width must be between 1 and 100"
except (TypeError, ValueError):
return False, "Line width must be a number"

timestamp = stroke.get("timestamp")
if timestamp is not None:
try:
int(timestamp)
except (TypeError, ValueError):
return False, "Timestamp must be a number"

drawing_id = stroke.get("drawingId")
if drawing_id is not None and not isinstance(drawing_id, str):
return False, "drawingId must be a string"

stroke_id = stroke.get("id")
if stroke_id is not None and not isinstance(stroke_id, str):
return False, "id must be a string"

# Validate optional signature fields (will be enforced for secure rooms in handler)
signature = value.get("signature")
Expand Down
149 changes: 149 additions & 0 deletions backend/routes/ai_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from flask import Blueprint, request, jsonify
from services.llm_service import prompt_to_drawings, complete_shape_from_canvas, beautify_canvas_state
# from services.image_generation_service import (
# text_to_image as img_text_to_image,
# )
import logging
import base64
import io

ai_assistant_bp = Blueprint('ai_assistant', __name__)
logger = logging.getLogger(__name__)


@ai_assistant_bp.route('/api/ai_assistant/drawing', methods=['POST'])
def text_to_drawings():
"""
Body: { "prompt": "<natural language description>", "canvasState": { ... } }
Returns: stroke JSON or an error payload.
"""
try:
payload = request.get_json(silent=True) or {}
prompt = payload.get("prompt")
canvas_state = payload.get("canvasState") or {}

if not isinstance(prompt, str) or not prompt.strip():
return jsonify({
"error": "bad_request",
"detail": "Missing or invalid 'prompt' (string)."
}), 400

if not isinstance(canvas_state, dict):
return jsonify({
"error": "bad_request",
"detail": "Invalid 'canvasState' (object)."
}), 400

logger.info("AI drawing requested: route entered")
logger.info("Calling prompt_to_drawings now")
result = prompt_to_drawings(prompt.strip(), canvas_state)
logger.info("prompt_to_drawings returned")

if isinstance(result, dict) and "error" in result:
logger.warning("AI drawing failed: %s", result)
return jsonify({
"error": "upstream_model_error",
"detail": result
}), 502

logger.info("AI drawing generated successfully")
logger.debug("AI drawing result: %r", result)
return jsonify(result), 200

except Exception as e:
logger.exception("Unhandled error in /drawing")
return jsonify({"error": "server_error", "detail": str(e)}), 500


@ai_assistant_bp.route('/api/ai_assistant/complete', methods=['POST'])
def shape_completion():
"""
Body: { "canvasState": { ... } }
Returns: { complete, confidence, object{ color, lineWidth, pathData{...} } } or an error payload.
"""
try:
payload = request.get_json(silent=True) or {}
canvas_state = payload.get("canvasState")
if not isinstance(canvas_state, dict):
return jsonify({"error": "bad_request", "detail": "Missing or invalid 'canvas_state' (object)."}), 400

logger.info("AI shape completion requested")
suggestion = complete_shape_from_canvas(canvas_state)

if not isinstance(canvas_state, dict):
return jsonify({
"error": "bad_request",
"detail": "Missing or invalid 'canvasState' (object)."
}), 400

return jsonify(suggestion), 200
except Exception as e:
logger.exception("Unhandled error in /complete")
return jsonify({"error": "server_error", "detail": str(e)}), 500


@ai_assistant_bp.route('/api/ai_assistant/image', methods=['POST'])
def text_to_image():
"""
TODO: To be implemented
Body: { "prompt": "<string>", "width"?: int, "height"?: int, "style"?: str }
Returns: { "imageDataUrl": "data:image/png;base64,..." }
"""
try:
payload = request.get_json(silent=True) or {}
prompt = payload.get("prompt", "")
width = payload.get("width") or 512
height = payload.get("height") or 512
style = payload.get("style") or "default"

if not isinstance(prompt, str) or not prompt.strip():
return jsonify({
"error": "bad_request",
"detail": "Missing or invalid 'prompt' (string)."
}), 400

logger.info("AI text-to-image requested")

pil_image = [] # img_text_to_image(prompt.strip(), width=width, height=height, style=style)

buf = io.BytesIO()
pil_image.save(buf, format="PNG")
buf.seek(0)
encoded = base64.b64encode(buf.read()).decode("utf-8")
data_url = f"data:image/png;base64,{encoded}"

return jsonify({"imageDataUrl": data_url}), 200

except Exception as e:
logger.exception("Unhandled error in /image")
return jsonify({"error": "server_error", "detail": str(e)}), 500


@ai_assistant_bp.route("/api/ai_assistant/beautify", methods=["POST"])
def beautify_sketch():
try:
payload = request.get_json(silent=True) or {}
canvas_state = payload.get("canvasState")

if not isinstance(canvas_state, dict):
return jsonify({
"error": "bad_request",
"detail": "Missing or invalid 'canvasState' (object)."
}), 400

result = beautify_canvas_state(canvas_state)
# print("\n\ncanvas_state!!!", canvas_state, "\n\n")
# print("\n\nResult!!!", result, "\n\n")

if not isinstance(result, dict) or "objects" not in result:
logger.warning("Beautify returned invalid payload: %r", result)
return jsonify({
"error": "upstream_model_error",
"detail": "Beautify model returned invalid payload."
}), 502

return jsonify(result), 200

except Exception as e:
logger.exception("Unhandled error in /beautify")
return jsonify({"error": "server_error", "detail": str(e)}), 500
40 changes: 39 additions & 1 deletion backend/routes/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,10 @@ def admin_fill_wrapped_key(roomId):
@require_room_access(room_id_param="roomId")
@limiter.limit(f"{RATE_LIMIT_STROKE_MINUTE}/minute")
@validate_request_data({
"stroke": {"validator": lambda v: (isinstance(v, dict), "Stroke must be an object") if not isinstance(v, dict) else (True, None), "required": True},
"stroke": {
"validator": lambda v: validate_stroke_payload(request.get_json() or {}),
"required": True
},
"signature": {"validator": validate_optional_string(max_length=1000), "required": False},
"signerPubKey": {"validator": validate_optional_string(max_length=1000), "required": False}
})
Expand Down Expand Up @@ -732,6 +735,19 @@ def post_stroke(roomId):
parent_paste_id = stroke.get("parentPasteId", "NOT SET")
logger.warning(f"POST STROKE DEBUG - roomId={roomId}, brushType={brush_type}, brushParams={brush_params}, parentPasteId={parent_paste_id}")
logger.warning(f"POST STROKE DEBUG - Full stroke object: {json.dumps(stroke, default=str)}")
path_data = stroke.get("pathData")
is_ai_batch_marker = isinstance(path_data, dict) and path_data.get("tool") == "paste" and path_data.get("aiGenerated") is True
is_ai_child_stroke = bool(stroke.get("parentPasteId")) or (
isinstance(path_data, dict) and bool(path_data.get("parentPasteId"))
)
if is_ai_batch_marker or is_ai_child_stroke:
logger.info(
"AI drawing submitted: roomId=%s drawingId=%s parentPasteId=%s batchMarker=%s",
roomId,
stroke.get("drawingId") or stroke.get("id"),
stroke.get("parentPasteId") or (path_data.get("parentPasteId") if isinstance(path_data, dict) else None),
is_ai_batch_marker,
)
except Exception as e:
logger.error(f"POST STROKE DEBUG - Error logging stroke: {e}")

Expand Down Expand Up @@ -827,6 +843,20 @@ def post_stroke(roomId):
logger.warning(f"STORING FULL STROKE: {json.dumps(stroke, default=str)[:500]}...")

strokes_coll.insert_one({"roomId": roomId, "ts": stroke["ts"], "stroke": stroke})
try:
path_data = stroke.get("pathData")
if (
stroke.get("parentPasteId") or
(isinstance(path_data, dict) and (path_data.get("parentPasteId") or path_data.get("aiGenerated")))
):
logger.info(
"AI drawing stored: roomId=%s strokeId=%s parentPasteId=%s",
roomId,
stroke.get("id") or stroke.get("drawingId"),
stroke.get("parentPasteId") or (path_data.get("parentPasteId") if isinstance(path_data, dict) else None),
)
except Exception:
logger.exception("Failed to emit AI storage log for room %s", roomId)

rooms_coll.update_one({"_id": room["_id"]}, {"$set": {"updatedAt": datetime.utcnow()}})

Expand Down Expand Up @@ -1340,6 +1370,14 @@ def get_strokes(roomId):
"filterParams": stroke_data["filterParams"],
}

if parent_paste_id or (isinstance(stroke_data.get("pathData"), dict) and stroke_data["pathData"].get("aiGenerated")):
logger.info(
"AI drawing restored: roomId=%s strokeId=%s parentPasteId=%s",
roomId,
stroke_id,
parent_paste_id,
)

filtered_strokes.append(stroke_data)
if stroke_id:
seen_stroke_ids.add(stroke_id)
Expand Down
18 changes: 17 additions & 1 deletion backend/services/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import threading
import redis
import os
from pymongo import MongoClient
from pymongo.server_api import ServerApi
import logging
Expand Down Expand Up @@ -65,7 +66,22 @@

settings_coll = mongo_client[DB_NAME]["settings"]

redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=0)
DISABLE_REDIS = os.getenv("DISABLE_REDIS", "false").lower() == "true"

if DISABLE_REDIS:
try:
import fakeredis
redis_client = fakeredis.FakeStrictRedis(decode_responses=False)
print("Using fakeredis fallback (no real Redis server).")
except ImportError:
raise RuntimeError(
"DISABLE_REDIS=true but fakeredis is not installed. "
"Run: python -m pip install fakeredis"
)
else:
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=0)

lock = threading.Lock()

Expand Down
Empty file.
Loading