Last active
October 16, 2024 23:11
-
-
Save eric-humane/bc4a8853fe29a95305b5f22feec3d74c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Minimal hand-off demo: an English-only agent that owns a single tool for
// transferring the conversation to a Spanish-only agent.
const swarm = new Swarm();

// Hand-off target; it is created without any tools of its own.
const agentB = new Agent({
  name: "Spanish Agent",
  instructions: "You only speak Spanish.",
});

// Tool the entry agent can call to pass the conversation to agentB.
const transferToSpanish = Tool.TransferToAgent(
  agentB,
  "Transfer spanish speaking users immediately.",
);

// Entry agent for the run.
const agentA = new Agent({
  name: "English Agent",
  instructions: "You only speak English.",
  tools: [transferToSpanish],
});

try {
  // Start with one Spanish user message and an empty set of context variables.
  const result = await swarm.run(
    agentA,
    [{ role: "user", content: "Hola. ¿Como estás?" }],
    {},
  );
  console.log(result);
} catch (error) {
  console.log(error);
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "name": "swarm", | |
| "version": "0.0.1", | |
| "private": true, | |
| "dependencies": { | |
| "openai": "^4.67.3", | |
| "zod": "^3.23.8", | |
| "zod-to-openai-tool": "^0.13.1" | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import { t, createTools } from "zod-to-openai-tool"; | |
| import { z } from "zod"; | |
| import { | |
| ChatCompletionCreateParams, | |
| ChatCompletionMessageParam, | |
| ChatCompletionMessageToolCall, | |
| } from "openai/resources"; | |
| import { OpenAI } from "openai"; | |
/**
 * Loosest shape accepted for context variables: an object with string keys
 * whose value types are not statically known.
 */
interface DefaultContext {
  [variableName: string]: unknown;
}
/**
 * A named, described function that an agent can offer to the model.
 */
export class Tool {
  /** Function name under which the tool is registered. */
  name: string;
  /** Human-readable explanation of when the tool should be used. */
  description: string;
  /** Optional zod object schema describing the tool's arguments. */
  input?: z.AnyZodObject;
  /** Implementation invoked with the parsed arguments and the context. */
  func: AgentFunction;

  /**
   * @param params - Name, description, optional argument schema and the
   *   implementation of the tool.
   */
  constructor(params: {
    name: string;
    description: string;
    input?: z.AnyZodObject;
    func: AgentFunction;
  }) {
    this.name = params.name;
    this.description = params.description;
    this.input = params.input;
    this.func = params.func;
  }

  /**
   * Builds a tool whose implementation simply returns `agent`, which is how a
   * hand-off to another agent is signalled.
   *
   * The tool is named `transfer_to_<agent name>`, lower-cased, with every
   * character outside `a-z`, `0-9`, `_` and `-` replaced by `_` so the result
   * is usable as a function name. Spaces therefore map to `_` exactly as
   * before.
   *
   * @param agent - Agent returned by the tool when it is called.
   * @param description - Optional guidance on when to transfer. When omitted,
   *   a generic description mentioning the agent's name is used so that
   *   `Tool.description` is always a string.
   * @returns The transfer tool.
   */
  static TransferToAgent = (agent: Agent, description?: string) => {
    const slug = agent.name.toLowerCase().replace(/[^a-z0-9_-]/g, "_");
    return new Tool({
      name: `transfer_to_${slug}`,
      description: description ?? `Transfer the conversation to ${agent.name}.`,
      func: () => agent,
    });
  };
}
/**
 * Configuration for one conversational agent: the model it uses, its system
 * instructions and the tools it may call.
 */
export class Agent {
  /** Model used when the constructor is not given one. */
  static DEFAULT_MODEL = "gpt-4o";
  /** Instructions used when the constructor is not given any. */
  static DEFAULT_INSTRUCTIONS = "You are a helpful assistant.";

  name: string;
  model: string;
  /** Either fixed text or a function that derives the text from context. */
  instructions: string | ((input: Record<string, any>) => string);
  tools: Tool[];
  toolChoice?: string;
  parallelToolCalls: boolean;

  /**
   * @param options - Agent settings. Only `name` is required; `model`,
   *   `instructions`, `tools` and `parallelToolCalls` fall back to
   *   `DEFAULT_MODEL`, `DEFAULT_INSTRUCTIONS`, `[]` and `true` when left
   *   undefined.
   */
  constructor(options: {
    name: string;
    model?: string;
    instructions?: string | ((input: Record<string, any>) => string);
    tools?: Tool[];
    toolChoice?: string;
    parallelToolCalls?: boolean;
  }) {
    const {
      name,
      model = Agent.DEFAULT_MODEL,
      instructions = Agent.DEFAULT_INSTRUCTIONS,
      tools = [],
      toolChoice,
      parallelToolCalls = true,
    } = options;

    this.name = name;
    this.model = model;
    this.instructions = instructions;
    this.tools = tools;
    this.toolChoice = toolChoice;
    this.parallelToolCalls = parallelToolCalls;
  }

  /**
   * Resolves the instruction text for the given context variables.
   *
   * @param contextVariables - Passed to `instructions` when it is a function.
   * @returns The fixed text, or the result of calling the function.
   */
  getInstructions(contextVariables: Record<string, any>): string {
    if (typeof this.instructions !== "function") {
      return this.instructions;
    }
    // Invoked as a method so `this` inside the callback stays the agent.
    return this.instructions(contextVariables);
  }
}
/**
 * Outcome of a tool call: the text value, an optional agent and a set of
 * context variables.
 */
export class Result<TContext extends DefaultContext = DefaultContext> {
  /** Text produced by the tool. */
  value: string;
  /** Agent associated with the result, if any. */
  agent?: Agent;
  /** Context variables carried by the result. */
  contextVariables: TContext;

  /**
   * @param params - Value, optional agent and context variables to store.
   */
  constructor(params: { value: string; agent?: Agent; contextVariables: TContext }) {
    const { value, agent, contextVariables } = params;
    this.value = value;
    this.agent = agent;
    this.contextVariables = contextVariables;
  }
}
/**
 * Mutable accumulator for a list of messages, an optional agent and a set of
 * context variables.
 */
export class Response<TContext extends DefaultContext = DefaultContext> {
  messages: ChatCompletionMessageParam[];
  agent?: Agent;
  contextVariables: TContext;

  /**
   * @param params - Initial state. The `messages` array is stored as given,
   *   not copied, so `addMessage` appends to the caller's array.
   */
  constructor(params: {
    messages: ChatCompletionMessageParam[];
    agent?: Agent;
    contextVariables: TContext;
  }) {
    const { messages, agent, contextVariables } = params;
    this.messages = messages;
    this.agent = agent;
    this.contextVariables = contextVariables;
  }

  /** Appends `message` to the end of the message list. */
  addMessage(message: ChatCompletionMessageParam) {
    const { messages } = this;
    messages.push(message);
  }

  /**
   * Replaces the context with a new object in which the entries of `newVars`
   * override the current ones.
   */
  updateContextVariables(newVars: Partial<TContext>) {
    const merged = { ...this.contextVariables, ...newVars };
    this.contextVariables = merged;
  }

  /** Sets the agent recorded on this response. */
  setAgent(agent: Agent) {
    this.agent = agent;
  }
}
/**
 * Signature of a tool implementation.
 *
 * It receives the tool-call arguments and the current context variables, and
 * may return its value directly or wrapped in a promise.
 *
 * Note that the default return type `string | Agent | unknown` is equivalent
 * to `unknown` for the type checker; the named members only document the
 * expected kinds of return values.
 */
type AgentFunction<
  TArgs = any,
  TContext extends DefaultContext = DefaultContext,
  TReturn = string | Agent | unknown,
> = (args: TArgs, ctx: TContext) => TReturn | Promise<TReturn>;
/**
 * Drives a conversation across one or more agents: requests completions,
 * executes the tool calls the model asks for and follows agent hand-offs.
 */
export class Swarm<TContext extends DefaultContext = DefaultContext> {
  client: OpenAI;

  /**
   * @param client - OpenAI client to use. Defaults to a new client built with
   *   the `OPENAI_API_KEY` environment variable.
   */
  constructor(client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })) {
    this.client = client;
  }

  /** Extracts a printable message from any thrown value. */
  private static errorMessage(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /**
   * Returns the messages to send for one request: the full history when it
   * fits within `historyLimit`, otherwise the first message followed by the
   * most recent `historyLimit - 1` messages. The input array is not modified.
   */
  private static limitHistory(
    history: ChatCompletionMessageParam[],
    historyLimit: number
  ): ChatCompletionMessageParam[] {
    if (history.length <= historyLimit) {
      return history;
    }
    const head = history.slice(0, 1);
    // An explicit start index avoids `slice(-0)`, which would return the whole
    // array when historyLimit is 1.
    const tail = history.slice(history.length - (historyLimit - 1));
    // A tool result whose originating assistant message was cut off is
    // rejected by the API, so drop such leading tool messages.
    let start = 0;
    while (start < tail.length && tail[start]?.role === "tool") {
      start++;
    }
    return [...head, ...tail.slice(start)];
  }

  /**
   * Requests one non-streaming chat completion for `agent`.
   *
   * @param agent - Agent whose instructions, tools and model are used.
   * @param history - Conversation messages sent after the system message.
   * @param contextVariables - Passed to the agent to resolve its instructions.
   * @param modelOverride - Model to use instead of `agent.model`, if given.
   * @returns The pending completion request.
   */
  private getChatCompletion(
    agent: Agent,
    history: ChatCompletionMessageParam[],
    contextVariables: TContext,
    modelOverride?: string
  ) {
    console.log("Getting completion from agent: " + agent.name);
    const instructions = agent.getInstructions(contextVariables);
    const messages: ChatCompletionMessageParam[] = [
      { role: "system", content: instructions },
      ...history,
    ];
    // Only the tool definition is needed here; the real implementation is
    // executed by handleToolCalls, hence the no-op runner.
    const tools = agent.tools.map(({ name, description, input }) => {
      const {
        tools: [tool],
      } = createTools({
        [name]: t
          .input(input ?? z.object({}))
          .describe(description)
          .run(() => {}),
      });
      return tool;
    });
    const createParams: ChatCompletionCreateParams = {
      model: modelOverride ?? agent.model,
      messages,
      stream: false,
      ...(tools.length > 0 && {
        tools,
        parallel_tool_calls: agent.parallelToolCalls,
      }),
      ...(agent.toolChoice && {
        tool_choice: { type: "function", function: { name: agent.toolChoice } },
      }),
    };
    return this.client.chat.completions.create(createParams);
  }

  /**
   * Executes all requested tool calls concurrently and gathers their outcome.
   *
   * Unknown tools, unparsable arguments and exceptions thrown by a tool do not
   * reject; each is reported back as the content of that call's tool message.
   *
   * @param toolCalls - Tool calls taken from the assistant message.
   * @param tools - Tools of the active agent, looked up by name.
   * @param contextVariables - Context handed to every tool implementation.
   * @returns A response holding one tool message per call (in call order), the
   *   merged context variables and the last agent returned by any tool.
   */
  private async handleToolCalls(
    toolCalls: ChatCompletionMessageToolCall[],
    tools: Tool[],
    contextVariables: TContext
  ) {
    type ToolCallOutcome = {
      message: ChatCompletionMessageParam;
      contextVariables?: Partial<TContext>;
      agent?: Agent;
    };

    const partialResponse = new Response<TContext>({
      messages: [],
      contextVariables: { ...contextVariables },
    });
    const functionMap = new Map(tools.map((func) => [func.name, func]));

    // Process tool calls in parallel and collect results
    const toolCallResults = await Promise.all(
      toolCalls.map(async (toolCall): Promise<ToolCallOutcome> => {
        const tool_call_id = toolCall.id;
        const name = toolCall.function.name;
        const toolMessage = (content: string): ChatCompletionMessageParam => ({
          role: "tool",
          tool_call_id,
          content,
        });

        const tool = functionMap.get(name);
        if (!tool) {
          console.error(`Tool ${name} not found in function map.`);
          return { message: toolMessage(`Error: Tool ${name} not found.`) };
        }

        let args: unknown;
        try {
          // An empty argument string is treated as "no arguments".
          args = JSON.parse(toolCall.function.arguments || "{}");
        } catch (error: unknown) {
          console.error(`Failed to parse arguments for tool ${name}:`, error);
          return {
            message: toolMessage(
              `Error: Invalid arguments for tool ${name}. Details: ${Swarm.errorMessage(error)}`
            ),
          };
        }

        console.log(`Processing tool call: ${name} with arguments`, args);
        try {
          const result = this.handleResult(await tool.func(args, contextVariables));
          return {
            message: toolMessage(result.value),
            contextVariables: result.contextVariables,
            agent: result.agent,
          };
        } catch (error: unknown) {
          console.error(`Error executing tool ${name}:`, error);
          return {
            message: toolMessage(
              `Error executing tool ${name}. Details: ${Swarm.errorMessage(error)}`
            ),
          };
        }
      })
    );

    // Combine results after all tool calls have completed
    for (const outcome of toolCallResults) {
      partialResponse.addMessage(outcome.message);
      if (outcome.contextVariables) {
        partialResponse.updateContextVariables(outcome.contextVariables);
      }
      if (outcome.agent) {
        partialResponse.setAgent(outcome.agent);
      }
    }
    return partialResponse;
  }

  /**
   * Normalizes whatever a tool returned into a `Result`.
   *
   * - A `Result` is returned unchanged.
   * - An `Agent` becomes a hand-off result whose value is
   *   `{"assistant": <agent name>}` as JSON.
   * - A string is used as the value as-is.
   * - Anything else is JSON-serialized; values JSON cannot represent (such as
   *   `undefined`) fall back to `String(value)` so the value is always a
   *   string.
   */
  private handleResult(result: Result | Agent | unknown): Result<TContext> {
    if (result instanceof Result) {
      return result;
    }
    if (result instanceof Agent) {
      return new Result<TContext>({
        value: JSON.stringify({ assistant: result.name }),
        agent: result,
        contextVariables: {} as TContext,
      });
    }
    const value =
      typeof result === "string" ? result : (JSON.stringify(result) ?? String(result));
    return new Result<TContext>({
      value,
      contextVariables: {} as TContext,
    });
  }

  /**
   * Runs the conversation loop starting with `agent`.
   *
   * Each iteration requests a completion from the active agent. If the reply
   * contains tool calls (and `executeTools` is true) they are executed, their
   * messages are appended, the context variables they returned are merged in
   * and any returned agent becomes the active one. The loop ends when a reply
   * has no tool calls or when at least `maxTurns` new messages were produced.
   *
   * @param agent - Agent that handles the first turn.
   * @param messages - Initial conversation; not modified.
   * @param contextVariables - Initial context variables; not modified.
   * @param maxTurns - Upper bound on new messages before the loop stops.
   * @param executeTools - When false, stop after the first completion even if
   *   it requests tool calls.
   * @param modelOverride - Model used for every completion instead of the
   *   agents' own models.
   * @param historyLimit - Maximum number of messages sent per request; older
   *   messages (except the first) are left out of the request only.
   * @returns All messages produced during the run, the final active agent and
   *   the final context variables.
   * @throws Error when `maxTurns` or `historyLimit` is not a positive integer,
   *   when `contextVariables` is not a non-null object, or when a completion
   *   contains no choices.
   */
  public async run(
    agent: Agent,
    messages: ChatCompletionMessageParam[],
    contextVariables: TContext,
    maxTurns: number = 25,
    executeTools: boolean = true,
    modelOverride?: string,
    historyLimit: number = 100
  ): Promise<Response<TContext>> {
    if (!Number.isInteger(maxTurns) || maxTurns <= 0) {
      throw new Error("maxTurns must be a positive integer");
    }
    if (!Number.isInteger(historyLimit) || historyLimit <= 0) {
      throw new Error("historyLimit must be a positive integer");
    }
    if (typeof contextVariables !== "object" || contextVariables === null) {
      throw new Error("contextVariables must be a non-null object");
    }

    let activeAgent = agent;
    let context: TContext = { ...contextVariables };
    // The full history is kept intact; truncation is applied per request so
    // that turn counting and the returned messages stay correct.
    const history = [...messages];
    const initLen = messages.length;

    while (history.length - initLen < maxTurns && activeAgent) {
      const completion = await this.getChatCompletion(
        activeAgent,
        Swarm.limitHistory(history, historyLimit),
        context,
        modelOverride
      );
      const message = completion.choices[0]?.message;
      if (!message) {
        throw new Error("Completion contained no choices");
      }
      console.log("Received completion:", message);
      history.push(message);

      // An absent or empty tool_calls list both mean the turn is over.
      if (!message.tool_calls?.length || !executeTools) {
        console.log("Ending turn.");
        break;
      }

      const partialResponse = await this.handleToolCalls(
        message.tool_calls,
        activeAgent.tools,
        context
      );
      history.push(...partialResponse.messages);
      // Tool-provided variables take precedence over the existing ones.
      context = { ...context, ...partialResponse.contextVariables };
      if (partialResponse.agent) {
        activeAgent = partialResponse.agent;
      }
    }

    return new Response<TContext>({
      messages: history.slice(initLen),
      agent: activeAgent,
      contextVariables: context,
    });
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment