Created
March 9, 2026 22:49
-
-
Save mbstacy/db40f6025584442d15951245d538d951 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(ellmer)
library(httr2)

# --- Environment configuration ---------------------------------------------
# The Harvard Apigee gateway fronts AWS Bedrock, so real AWS credentials are
# not used; the Bedrock client still expects the AWS variables to be set,
# hence the "placeholder" values below.
# NOTE(review): HUIT_API_KEY is copied from APIGEE_API_KEY so both the chat
# client and embed_huit() read a single variable — APIGEE_API_KEY must be set
# before this script is sourced, or the key will be empty.
Sys.setenv(
  HUIT_API_KEY = Sys.getenv("APIGEE_API_KEY"),
  AWS_ACCESS_KEY_ID = "placeholder",
  AWS_SECRET_ACCESS_KEY = "placeholder",
  AWS_REGION = "us-east-1"
)

# Gateway endpoint and embedding model id used by embed_huit().
BASE_URL <- "https://go.apis.huit.harvard.edu/ais-bedrock-llm/v2"
EMBED_MODEL <- "amazon.titan-embed-text-v2:0"

# Chat client: Claude Sonnet routed through the gateway. The gateway key is
# supplied as an extra request header ("x-api-key") on every call.
mychat <- chat_aws_bedrock(
  base_url = BASE_URL,
  model = "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
  api_headers = c("x-api-key" = Sys.getenv("HUIT_API_KEY"))
)
# Embed texts with the Titan embedding model via the HUIT gateway.
#
# @param texts Character vector of texts to embed (one request per element).
# @param model Bedrock embedding model id.
# @param base_url Gateway base URL; defaults to the script-level BASE_URL so
#   existing callers are unaffected, but tests/alternate gateways can override.
# @return A numeric matrix with one row per element of `texts`.
embed_huit <- function(texts, model = EMBED_MODEL, base_url = BASE_URL) {
  # Fail fast on bad input instead of issuing zero or malformed requests.
  stopifnot(is.character(texts), length(texts) > 0)
  results <- lapply(texts, function(text) {
    resp <- request(paste0(base_url, "/model/", model, "/invoke")) |>
      req_headers(
        "x-api-key" = Sys.getenv("HUIT_API_KEY"),
        "Content-Type" = "application/json"
      ) |>
      req_body_json(list(inputText = text)) |>
      req_perform()
    as.numeric(unlist(resp_body_json(resp)$embedding))
  })
  # One row per input text; all rows share the model's embedding dimension.
  do.call(rbind, results)
}
# Cosine similarity between two numeric vectors.
#
# @param a,b Numeric vectors of equal length.
# @return A scalar in [-1, 1]; NaN if either vector has zero norm.
cosine_sim <- function(a, b) {
  # Guard against silent recycling when the lengths differ.
  stopifnot(is.numeric(a), is.numeric(b), length(a) == length(b))
  sum(a * b) / (sqrt(sum(a^2)) * sqrt(sum(b^2)))
}
# Retrieve the top_k documents most similar to a query.
#
# @param query A single query string (embedded via embed_huit()).
# @param doc_texts Character vector of documents.
# @param doc_embeddings Numeric matrix of embeddings, one row per document.
# @param top_k Number of documents to return (clamped to the corpus size).
# @return Character vector of the most similar documents, best match first.
retrieve <- function(query, doc_texts, doc_embeddings, top_k = 3) {
  query_vec <- embed_huit(query)[1, ]
  scores <- apply(doc_embeddings, 1, cosine_sim, b = query_vec)
  # Clamp top_k: asking for more documents than exist would otherwise index
  # past the end of the score order and return NA entries.
  top_k <- min(top_k, length(doc_texts))
  top_idx <- order(scores, decreasing = TRUE)[seq_len(top_k)]
  doc_texts[top_idx]
}
# Answer a question with retrieval-augmented generation: fetch the most
# relevant documents, splice them into the prompt as context, and send the
# prompt to the chat model.
rag_chat <- function(question, doc_texts, doc_embeddings, top_k = 3) {
  relevant_docs <- retrieve(question, doc_texts, doc_embeddings, top_k)
  merged_context <- paste(relevant_docs, collapse = "\n\n---\n\n")
  prompt <- sprintf(
    "Use the following context to answer the question.\n\nContext:\n%s\n\nQuestion: %s",
    merged_context, question
  )
  mychat$chat(prompt)
}
# ---------------------------------------------------------------------------
# Example usage
# ---------------------------------------------------------------------------
# A tiny in-memory corpus to demonstrate the RAG pipeline end to end.
corpus <- c(
  "R is a programming language for statistical computing and graphics.",
  "The ellmer package provides tools for calling LLMs from R.",
  "Embeddings are numeric vector representations of text used in semantic search.",
  "RAG stands for Retrieval-Augmented Generation, combining search with LLMs.",
  "AWS Bedrock is a managed service for running foundation models on AWS."
)

cat("Embedding documents...\n")
corpus_embeddings <- embed_huit(corpus)

question <- "What is RAG?"
cat("\nQuestion:", question, "\n\n")

answer <- rag_chat(question, corpus, corpus_embeddings)
cat("Answer:", answer, "\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment