---
library_name: transformers
pipeline_tag: text-generation
tags:
- mistral
- text-to-sql
- sql
language:
- en
license: apache-2.0
base_model: mistralai/Mistral-7B-v0.1
---
# Mistral-7B SQL (fine-tuned)
Fine-tuned from `mistralai/Mistral-7B-v0.1` for Text-to-SQL generation on the `b-mc2/sql-create-context` dataset.
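
To inspect the data format the model was tuned on, the dataset can be loaded with the `datasets` library. This is a minimal sketch; the `question`/`context`/`answer` field names follow the dataset's card and may differ if the dataset is updated.

```python
from datasets import load_dataset

# Load the Text-to-SQL training data used for fine-tuning.
ds = load_dataset("b-mc2/sql-create-context", split="train")

sample = ds[0]
print(sample["question"])  # natural-language question
print(sample["context"])   # CREATE TABLE statement(s)
print(sample["answer"])    # target SQL query
```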
## Usage (Transformers)
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "kirankotha/mistral7b-sql-model"

tok = AutoTokenizer.from_pretrained(model_id)
mdl = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16
)

# Instruction-style prompt: question under ### Input, schema under ### Context.
prompt = (
    "You are a text-to-SQL model.\n"
    "### Input:\n"
    "Which product has the highest price?\n"
    "### Context:\n"
    "CREATE TABLE products (id INTEGER, name TEXT, price REAL)\n"
    "### Response:\n"
)

ids = tok(prompt, return_tensors="pt").to(mdl.device)
# Greedy decoding; Mistral's tokenizer has no pad token, so reuse the EOS token id.
out = mdl.generate(**ids, max_new_tokens=100, do_sample=False, pad_token_id=tok.eos_token_id)
print(tok.decode(out[0], skip_special_tokens=True))
```
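
Alternatively, the model can be run through the high-level `pipeline` API. This is a minimal sketch using the standard `text-generation` pipeline with the same prompt format as above; generation settings are an assumption and can be adjusted.

```python
from transformers import pipeline
import torch

# Text-generation pipeline; device_map="auto" places the model on available devices.
generator = pipeline(
    "text-generation",
    model="kirankotha/mistral7b-sql-model",
    torch_dtype=torch.float16,
    device_map="auto",
)

prompt = (
    "You are a text-to-SQL model.\n"
    "### Input:\nWhich product has the highest price?\n"
    "### Context:\nCREATE TABLE products (id INTEGER, name TEXT, price REAL)\n"
    "### Response:\n"
)

# return_full_text=False strips the prompt and keeps only the generated SQL.
out = generator(prompt, max_new_tokens=100, do_sample=False, return_full_text=False)
print(out[0]["generated_text"])
```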