Created
January 25, 2023 09:02
-
-
Save nikitastaf1996/f859d4e9f14dc0e3115268126dc2d8dd to your computer and use it in GitHub Desktop.
Result
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/nikitastaf1996/f859d4e9f14dc0e3115268126dc2d8dd/result.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "!git clone https://github.com/HazyResearch/H3" | |
| ], | |
| "metadata": { | |
| "id": "63JZqoxBtNr7" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "%cd H3" | |
| ], | |
| "metadata": { | |
| "id": "S0wJ1rl6HkMW" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "!wget https://huggingface.co/danfu09/H3-1.3B/resolve/main/model.pt" | |
| ], | |
| "metadata": { | |
| "id": "2x1UKWPluUQ-" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "!pip install torch einops flash-attn transformers omegaconf pykeops pytorch-lightning numba" | |
| ], | |
| "metadata": { | |
| "id": "XTIr1kDmvVsx" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from typing import Optional\n", | |
| "\n", | |
| "import torch\n", | |
| "import torch.nn.functional as F\n", | |
| "\n", | |
| "from einops import rearrange\n", | |
| "\n", | |
| "from transformers import GPT2Tokenizer, GPT2Config, GPT2LMHeadModel\n", | |
| "\n", | |
| "from src.models.ssm.h3 import H3\n", | |
| "from src.models.ssm_seq import SSMLMHeadModel\n", | |
| "\n", | |
| "from flash_attn.utils.generation import InferenceParams\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "device = 'cuda'\n", | |
| "dtype = torch.float16\n", | |
| "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", | |
| "\n", | |
| "\n", | |
| "d_model = 2048\n", | |
| "n_layer = 24\n", | |
| "ssm_cfg = dict(mode='diag', measure='diag-lin')\n", | |
| "attn_layer_idx = [8,16]\n", | |
| "attn_cfg = dict(num_heads=16)\n", | |
| "model = SSMLMHeadModel(d_model, n_layer=n_layer, d_inner=4 * d_model, vocab_size=len(tokenizer),\n", | |
| " ssm_cfg=ssm_cfg, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg,\n", | |
| " pad_vocab_size_multiple=8).to(device=device)\n", | |
| "if \"/content/H3/model.pt\" is not None:\n", | |
| " state_dict = torch.load(\"/content/H3/model.pt\", map_location=device)\n", | |
| " if 'pytorch-lightning_version' in state_dict:\n", | |
| " state_dict = {k[len('model.'):]: v for k, v in state_dict['state_dict'].items()\n", | |
| " if k.startswith('model.')}\n", | |
| " model.load_state_dict(state_dict)\n", | |
| "model.eval()\n", | |
| "# Only cast the nn.Linear parameters to dtype, the SSM params stay in fp32\n", | |
| "# Pytorch lacks support for complex32 (i.e. complex<float16>) and complex<bfloat16>.\n", | |
| "for name, module in model.named_modules():\n", | |
| " if isinstance(module, (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)):\n", | |
| " module.to(dtype=dtype)" | |
| ], | |
| "metadata": { | |
| "id": "-25ZoIjXMNV4" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "#import requests\n", | |
| "#url = \"https://www.google.com/search?q=python+make+a+request\"\n", | |
| "\n", | |
| "#response = requests.get(url)\n", | |
| "#html_content = response.text\n", | |
| "#torch.random.manual_seed(0);\n", | |
| "\n", | |
| "prompt = \"Please write an essay about birds\" #@param {type: \"string\"}\n", | |
| "#prompt = html_content + prompt\n", | |
| "input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)\n", | |
| "\n", | |
| "# Generate text\n", | |
| "gen_len = 1024\n", | |
| "max_length = input_ids.shape[1] + gen_len\n", | |
| "generated_text = model.generate(input_ids=input_ids, max_length=max_length,\n", | |
| " return_dict_in_generate=True, output_scores=True, timing=False)\n", | |
| "output_ids = generated_text[0][0]\n", | |
| "\n", | |
| "# Decode the output\n", | |
| "output_tokens = tokenizer.decode(output_ids)\n", | |
| "\n", | |
| "# Print the output\n", | |
| "print(output_tokens)" | |
| ], | |
| "metadata": { | |
| "id": "cAp0_p4bPn1z" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "private_outputs": true, | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "name": "python3" | |
| }, | |
| "accelerator": "GPU", | |
| "gpuClass": "standard" | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment