import gradio as gr
from transformers_js_py import pipeline
import json
import re
import js
# ========== 1. Load models ==========
# Top-level `await` only works when the host executes this module in an async
# context (presumably Pyodide / Gradio-Lite, given the `js` import) — TODO confirm.
# Each pipeline(task, model_id) call builds a transformers.js pipeline; model
# weights are presumably downloaded on first load, so startup may be slow.
# Sentence embedder used by embed() for semantic retrieval.
embedder = await pipeline("feature-extraction", "Xenova/all-MiniLM-L6-v2")
# Extractive question-answering model used by answer() in "QA" mode.
qa_pipe = await pipeline("question-answering", "Xenova/distilbert-base-cased-distilled-squad")
# Summarisation model used by answer() in "Summarization" mode.
summarizer_pipe = await pipeline("summarization", "Xenova/distilbart-cnn-12-6")
# ========== 2. Globals ==========
# Mutable module state shared by the helpers and UI callbacks below.
all_chunks = []          # chunk dicts ({"source", "text", "embedding"}) built by preprocess_urls()
last_top_chunks = []     # chunks selected for the most recent answer()
in_memory_cache = None   # storage behind save_cache()/load_cache(); None means nothing cached
current_mode = "QA"      # "QA" or "Summarization"; updated by the mode radio callback
top_k = 1                # number of best-scoring chunks fed to each answer
# ========== 3. Static list of URLs ==========
# Pages of the CMDP site to index, fetched in this order; the bare site root is last.
urls = [
    "https://cmdp.lamps.yorku.ca/" + path
    for path in (
        "data-of-peel-region/",
        "ontario/",
        "global/",
        "other-vbds/",
        "data-tables/",
        "station-observation/",
        "era5-land/",
        "ar6-data/",
        "climate-datasets/",
        "gis/",
        "cmdp/",
        "",
    )
]
# ========== 4. Helpers ==========
async def embed(text):
    """Embed *text* with the feature-extraction pipeline and return the vector as a list.

    The pipeline is asked for mean pooling with normalisation
    (``pooling="mean"``, ``normalize=True``).
    """
    output = await embedder(text, pooling="mean", normalize=True)
    vector = output.data  # presumably a JS typed-array proxy — confirm
    return list(vector)
def cosine_sim(a, b):
    """Return the cosine similarity of numeric vectors *a* and *b*.

    Returns 0.0 when either vector has zero norm (including empty input), so
    callers never divide by zero. Components are paired with ``zip``, so if
    the lengths differ the trailing components of the longer vector are
    ignored in the dot product (they still count toward that vector's norm).
    """
    # Generator expressions avoid building three throw-away lists per call;
    # this runs once per chunk for every query.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(y * y for y in b) ** 0.5
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
def chunk_text(text, size=300):
    """Split *text* on whitespace into consecutive chunks of at most *size* words.

    Each chunk is its words re-joined with single spaces; the last chunk may
    be shorter. Empty or whitespace-only text yields ``[]``.
    """
    tokens = text.split()
    boundaries = range(0, len(tokens), size)  # raises ValueError for size == 0
    windows = (tokens[lo:lo + size] for lo in boundaries)
    return [" ".join(window) for window in windows]
# Substrings that browsers put in the TypeError raised by fetch() for blocked
# (e.g. CORS) or failed network requests: Chromium, Firefox and Safari wording.
_NETWORK_ERROR_MARKERS = ("Failed to fetch", "NetworkError", "Load failed")


def _html_to_text(page):
    """Convert an HTML document string to plain text with collapsed whitespace."""
    from html import unescape

    # Drop script/style/noscript bodies entirely; plain tag stripping would
    # leave their JavaScript/CSS source in the indexed text.
    text = re.sub(r"(?is)<(script|style|noscript)\b.*?</\1\s*>", " ", page)
    text = re.sub(r"<[^>]*>", " ", text)
    # Decode entities such as &amp; or &#8217; after tags are gone.
    text = unescape(text)
    return re.sub(r"\s+", " ", text).strip()


async def fetch_texts(urls):
    """Fetch each URL with the browser ``fetch`` API and return its visible text.

    Args:
        urls: Iterable of URL strings, fetched sequentially.

    Returns:
        One ``{"url": ..., "text": ...}`` dict per input URL, in input order.
        When a URL cannot be used, ``text`` is a placeholder message beginning
        with ``"CORS error:"`` (browser-level network/CORS failure) or
        ``"Could not fetch"`` (HTTP error status or any other exception).
    """
    docs = []
    for url in urls:
        try:
            res = await js.fetch(url)
            if not res.ok:
                # An HTTP error page (404, 500, ...) is not site content; record
                # it with the "Could not fetch" placeholder prefix instead.
                docs.append({"url": url, "text": f"Could not fetch. (HTTP {res.status})"})
                continue
            page = str(await res.text())
        except Exception as e:
            # Best-effort: one bad URL must not abort the whole crawl.
            if any(marker in str(e) for marker in _NETWORK_ERROR_MARKERS):
                error_msg = "CORS error: This URL cannot be accessed from browser due to security restrictions."
            else:
                error_msg = f"Could not fetch. ({e})"
            docs.append({"url": url, "text": error_msg})
            continue
        docs.append({"url": url, "text": _html_to_text(page)})
    return docs
def save_cache(chunks):
    """Store *chunks* in the module-level ``in_memory_cache``, replacing any earlier value.

    The object is kept by reference (no copy is made).
    """
    global in_memory_cache
    in_memory_cache = chunks
def load_cache():
    """Return the current value of the module-level ``in_memory_cache``.

    That is whatever save_cache() last stored, or the module's initial value
    if nothing has been saved yet.
    """
    cached_value = in_memory_cache
    return cached_value
def clear_cache_fn():
    """Reset the cache and all chunk state; return a status message for the UI.

    Rebinds ``in_memory_cache`` to None and ``all_chunks`` / ``last_top_chunks``
    to fresh empty lists.
    """
    global all_chunks, last_top_chunks, in_memory_cache
    in_memory_cache = None
    all_chunks, last_top_chunks = [], []
    return "🗑️ Cache cleared. Data will be reprocessed on next load."
# ========== 5. Preprocess static URLs ==========
async def preprocess_urls():
    """Fetch, chunk and embed every URL in ``urls``, then publish the chunks.

    If ``load_cache()`` returns a non-empty value, it is reused as
    ``all_chunks`` and nothing is fetched. Otherwise each fetched document is
    split with ``chunk_text``. Every chunk gets an ``"embedding"`` from
    ``embed``, and the finished list is assigned to the global ``all_chunks``
    and passed to ``save_cache``.

    Returns:
        A status string for display.
    """
    global all_chunks
    cached = load_cache()
    if cached:
        all_chunks = cached
        return "✅ Loaded embeddings from in-memory cache!"
    docs = await fetch_texts(urls)
    fresh_chunks = [
        {"source": doc["url"], "text": piece}
        for doc in docs
        for piece in chunk_text(doc["text"])
    ]
    # Embed into a local list and publish only when every chunk is complete.
    # If embed() raises midway, the global must not be left holding chunks
    # without an "embedding" key, because scoring reads that key directly.
    for chunk in fresh_chunks:
        chunk["embedding"] = await embed(chunk["text"])
    all_chunks = fresh_chunks
    save_cache(fresh_chunks)
    return "✅ URLs processed and embeddings ready!"
# Run preprocessing immediately on load
# Top-level await: blocks module execution until all URLs are fetched and
# embedded. The status string returned by preprocess_urls() is discarded.
await preprocess_urls()
# ========== 6. Answer function ==========
# Placeholder-text prefixes that fetch_texts() stores for URLs it could not load.
_FETCH_ERROR_PREFIXES = ("CORS error:", "Could not fetch")


def _is_fetch_error(chunk):
    """Return True if *chunk* holds a fetch-failure placeholder rather than page text."""
    return chunk["text"].startswith(_FETCH_ERROR_PREFIXES)


def _format_snippets(chunks):
    """Render *chunks* as numbered source snippets (first 600 characters of each)."""
    return "\n".join(
        f"[{i}] From {c['source']}:\n{c['text'][:600]}..."
        for i, c in enumerate(chunks, start=1)
    )


async def answer(message, chat_history):
    """Answer *message* from the top-scoring chunks (gr.ChatInterface callback).

    Chunks are ranked by cosine similarity between the query embedding and
    each chunk's ``"embedding"``. The best ``top_k`` are stored in
    ``last_top_chunks`` and concatenated as context. Depending on
    ``current_mode``, the context goes to the extractive QA pipeline ("QA")
    or to the summariser (anything else).

    Args:
        message: The user's question.
        chat_history: Supplied by Gradio; not used.

    Returns:
        A Markdown string for the chat window, or a warning string when no data
        is loaded or every URL failed to fetch.
    """
    global last_top_chunks
    if not all_chunks:
        return "⚠️ Data not loaded yet."
    valid_chunks = [c for c in all_chunks if not _is_fetch_error(c)]
    if not valid_chunks:
        failed_list = "\n".join(f"❌ {c['source']}" for c in all_chunks if _is_fetch_error(c))
        return f"⚠️ No valid data available. All URLs failed to load due to CORS restrictions.\n\n**Failed URLs:**\n{failed_list}\n\n**Solution:** Deploy this chatbot to the same domain (cmdp.lamps.yorku.ca) to avoid CORS restrictions."
    # gr.Slider delivers floats (e.g. 2.0), which are not valid slice bounds.
    k = int(top_k)
    q_emb = await embed(message)
    # sorted() is stable, so equal scores keep their original chunk order.
    ranked = sorted(valid_chunks, key=lambda c: cosine_sim(q_emb, c["embedding"]), reverse=True)
    last_top_chunks = ranked[:k]
    combined_text = " ".join(c["text"] for c in last_top_chunks)
    if current_mode == "QA":
        # Positional arguments: the JS API is answerer(question, context, options).
        # Python kwargs would be converted into one trailing JS options object,
        # leaving question/context unset.
        result = await qa_pipe(message, combined_text)
        body = f"👉 **Answer (QA Mode):** {result['answer']}"
    else:
        result = await summarizer_pipe(combined_text, max_length=240)
        body = f"📝 **Summary (Summarization Mode):** {result[0]['summary_text']}"
    return (
        f"📖 Based on Top-{k} chunks:\n\n"
        f"{body}\n\n"
        f"🔎 **Source snippets:**\n"
        f"{_format_snippets(last_top_chunks)}"
    )
# ========== 7. Gradio UI ==========
with gr.Blocks() as demo:
    gr.Markdown("## 📚 Semantic Chatbot — Static URLs")
    with gr.Row():
        mode_toggle = gr.Radio(["QA", "Summarization"], label="Choose Mode", value="QA")
        top_k_slider = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Top-K Chunks")
    chat = gr.ChatInterface(
        answer,
        title="🤖 Semantic Chatbot",
        description="Ask questions from static URLs. Toggle QA/Summarization and adjust Top-K chunks.",
    )

    # Event handlers: they write module globals that answer() reads, so the
    # settings are shared by every session of this app.
    def set_mode(mode):
        """Store the selected mode ("QA" or "Summarization") in ``current_mode``."""
        global current_mode
        current_mode = mode

    mode_toggle.change(fn=set_mode, inputs=mode_toggle, outputs=[])

    def set_top_k(k):
        """Store the slider value in ``top_k`` as an int."""
        global top_k
        # gr.Slider passes a float; a float top_k would break list slicing.
        top_k = int(k)

    top_k_slider.change(fn=set_top_k, inputs=top_k_slider, outputs=[])

demo.launch()
transformers-js-py