@nikitastaf1996
Created January 25, 2023 09:02
Result
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/nikitastaf1996/f859d4e9f14dc0e3115268126dc2d8dd/result.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!git clone https://github.com/HazyResearch/H3"
],
"metadata": {
"id": "63JZqoxBtNr7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%cd H3"
],
"metadata": {
"id": "S0wJ1rl6HkMW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!wget https://huggingface.co/danfu09/H3-1.3B/resolve/main/model.pt"
],
"metadata": {
"id": "2x1UKWPluUQ-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install torch einops flash-attn transformers omegaconf pykeops pytorch-lightning numba"
],
"metadata": {
"id": "XTIr1kDmvVsx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from typing import Optional\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"from einops import rearrange\n",
"\n",
"from transformers import GPT2Tokenizer, GPT2Config, GPT2LMHeadModel\n",
"\n",
"from src.models.ssm.h3 import H3\n",
"from src.models.ssm_seq import SSMLMHeadModel\n",
"\n",
"from flash_attn.utils.generation import InferenceParams\n",
"\n",
"\n",
"\n",
"device = 'cuda'\n",
"dtype = torch.float16\n",
"tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
"\n",
"\n",
"d_model = 2048\n",
"n_layer = 24\n",
"ssm_cfg = dict(mode='diag', measure='diag-lin')\n",
"attn_layer_idx = [8,16]\n",
"attn_cfg = dict(num_heads=16)\n",
"model = SSMLMHeadModel(d_model, n_layer=n_layer, d_inner=4 * d_model, vocab_size=len(tokenizer),\n",
" ssm_cfg=ssm_cfg, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg,\n",
" pad_vocab_size_multiple=8).to(device=device)\n",
"if \"/content/H3/model.pt\" is not None:\n",
" state_dict = torch.load(\"/content/H3/model.pt\", map_location=device)\n",
" if 'pytorch-lightning_version' in state_dict:\n",
" state_dict = {k[len('model.'):]: v for k, v in state_dict['state_dict'].items()\n",
" if k.startswith('model.')}\n",
" model.load_state_dict(state_dict)\n",
"model.eval()\n",
"# Only cast the nn.Linear parameters to dtype, the SSM params stay in fp32\n",
"# Pytorch lacks support for complex32 (i.e. complex<float16>) and complex<bfloat16>.\n",
"for name, module in model.named_modules():\n",
" if isinstance(module, (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)):\n",
" module.to(dtype=dtype)"
],
"metadata": {
"id": "-25ZoIjXMNV4"
},
"execution_count": null,
"outputs": []
},
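{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (an added sketch, not part of the original H3 example): count the parameters and tally their dtypes to confirm the mixed-precision split, i.e. fp16 for the cast Linear/Embedding/LayerNorm weights and fp32 for the SSM parameters."
]
},
{
"cell_type": "code",
"source": [
"# Assumes the model/checkpoint cell above ran successfully\n",
"from collections import Counter\n",
"\n",
"n_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"parameters: {n_params / 1e9:.2f}B\")  # ~1.3B for the H3-1.3B checkpoint\n",
"# Tally parameter dtypes: fp16 for the cast modules, fp32 for the SSM parameters\n",
"print(Counter(str(p.dtype) for p in model.parameters()))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},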
{
"cell_type": "code",
"source": [
"#import requests\n",
"#url = \"https://www.google.com/search?q=python+make+a+request\"\n",
"\n",
"#response = requests.get(url)\n",
"#html_content = response.text\n",
"#torch.random.manual_seed(0);\n",
"\n",
"prompt = \"Please write an essay about birds\" #@param {type: \"string\"}\n",
"#prompt = html_content + prompt\n",
"input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)\n",
"\n",
"# Generate text\n",
"gen_len = 1024\n",
"max_length = input_ids.shape[1] + gen_len\n",
"generated_text = model.generate(input_ids=input_ids, max_length=max_length,\n",
" return_dict_in_generate=True, output_scores=True, timing=False)\n",
"output_ids = generated_text[0][0]\n",
"\n",
"# Decode the output\n",
"output_tokens = tokenizer.decode(output_ids)\n",
"\n",
"# Print the output\n",
"print(output_tokens)"
],
"metadata": {
"id": "cAp0_p4bPn1z"
},
"execution_count": null,
"outputs": []
}
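,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional follow-up (an added sketch using only the objects defined above): since the returned sequence begins with the prompt tokens, slicing them off prints just the generated continuation."
]
},
{
"cell_type": "code",
"source": [
"# Drop the prompt tokens so only the newly generated text is printed\n",
"prompt_len = input_ids.shape[1]\n",
"print(tokenizer.decode(output_ids[prompt_len:]))"
],
"metadata": {},
"execution_count": null,
"outputs": []
}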
],
"metadata": {
"colab": {
"provenance": [],
"private_outputs": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 0
}