# Assignment 3: Post Training!
<br>

<p align="center">
  <img src="https://www.esports.net/wp-content/uploads/2023/11/eveelutions.jpg" width="650">
</p>

<br>

In this assignment you'll learn some of the **motivations and methods behind post-training language models**.

The assignment is broken down into **five** parts:

  - **1)** Set up a Hugging Face account and experiment with a small pre-trained model. For this assignment, we will be using Llama-3.2-1B.

  - **2)** Load and explore the GSM8K dataset.

  - **3)** Write an evaluation method to test the performance of your models on the GSM8K dataset.

  - **4)** Fine-tune a pretrained model to use mathematical CoT (Chain of Thought).

  - **5)** Further post train the model to solve grade school math questions.

After parts 3, 4, and 5 you will evaluate the model's performance on the math questions as the model evolves.


**Please note**: The expected training time is < 30 mins for part 4 and ~30 mins to 1 hr for part 5 on the L4 GPU, which is the one we recommend for this assignment.


**Background**

Recall our goal in the course thus far. Suppose we have a sequence of tokens $x = (x_1,x_2,x_3,...,x_T)$ drawn from some true but unknown distribution, call it $P_{data}(x)$.

We want to build a language model $P_{\theta}(x)$ parameterized by $\theta$, such that $P_{\theta}(x) \approx P_{data}(x)$. This means: if we sample from our model, we should get text that looks like it came from the *data* distribution. If we evaluate our model on real text, it should assign high probability to it.
$$
\begin{aligned}
\theta^{*} &= \arg\max_\theta \, \mathbb{E}_{x \sim P_{data}}\big[P_{\theta}(x)\big]\\
&= \arg\max_\theta \, \mathbb{E}_{x \sim P_{data}}\Big[\prod_{t=1}^{T}P_{\theta}(x_t|x_{<t})\Big]
\end{aligned}
$$
And since $\log(\cdot)$ is monotonically increasing, we can maximize the log-likelihood instead:
$$\theta^*= \arg\max_\theta \, \mathbb{E}_{x \sim P_{data}}\Big[\sum_{t=1}^{T}\log P_{\theta}(x_t|x_{<t})\Big]$$

So the loss is $\mathcal{L}(\theta) = - \mathbb{E}_{x \sim P_{data}}\sum_{t=1}^{T}\log P_{\theta}(x_t|x_{<t})$
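
As a small numeric illustration of the loss above, the sketch below scores a single sequence from its per-token conditional probabilities. The probabilities are made-up values rather than outputs of a real model, and `sequence_nll` is a hypothetical helper name used only here.

```python
import math

def sequence_nll(token_probs):
    """Negative log-likelihood of one sequence.

    token_probs[t] stands in for P_theta(x_t | x_<t); the values used
    below are invented for illustration.
    """
    return -sum(math.log(p) for p in token_probs)

# Hypothetical conditional probabilities for a 4-token sequence.
probs = [0.5, 0.25, 0.8, 0.1]

nll = sequence_nll(probs)
likelihood = math.prod(probs)  # the product form of P_theta(x)

# The sum of logs equals the log of the product, so both views agree.
print(f"NLL = {nll:.4f}, -log(product) = {-math.log(likelihood):.4f}")
```

Working with the sum of logs rather than the raw product also avoids numerical underflow on long sequences.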

But what if the distribution of word sequences in the training data isn't exactly what we want the model to produce? Post-training lets us shift the model's distribution away from $P_{data}$ toward a target distribution that better reflects desired goals and behaviors.

In this assignment, we'll be using two types of post-training: Supervised Fine Tuning (SFT) and Reinforcement Learning (RL).

# Part 1: HuggingFace setup and model exploration

Below is the Llama-3.2-1B model from Hugging Face.
To access it, you will have to
1) create a Hugging Face account with your student email,
2) request access to the Llama-3.2-1B model (https://huggingface.co/meta-llama/Llama-3.2-1B). Access should be granted within a couple of hours (for me it took < 1 hour).

3) create a Hugging Face access token with:

  -  **Read access to contents of all repos under your personal namespace**

  -  **Read access to contents of all public gated repos you can access**

  -  **Write access to contents/settings of all repos under your personal namespace**


  -  **IMPORTANT:** Having these permissions on the token means that anybody that has it can read from and write to any model stored on your account. Please do not share this token with anybody else or leave it defined in your submission. I suggest you keep it on a .txt file nearby so if you ever have to restart your Colab session, you can log in with it quickly.

4) run the cell below to log in, and type n when asked whether to add the token as a git credential.

In [None]:
%pip install huggingface_hub
%pip install -U bitsandbytes
!hf auth login

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer,StoppingCriteria, StoppingCriteriaList
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", dtype="auto", device_map="auto")

Now that you have loaded the model, play around with it. Modify the prefix any way you'd like, and see if you can observe any interesting behavior from the model. Try literally anything. (˶ᵔ ᵕ ᵔ˶)

Here is something I tried:

**Example prefix:** "hello world! my name is daniel and 2\*15\*1395358="

**Llama response:** "3. that's not too hard to figure out. if you want to know more about me, please check out the about page."



In [None]:
prefix = "hello world! my name is daniel and 2*15*1395358="
inputs = tokenizer(prefix, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    num_return_sequences=1,
    do_sample=True,
    top_p=0.9,
    temperature=0.8,
    pad_token_id=tokenizer.eos_token_id
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

## Exploring model behavior with prompting
Try getting Llama 3.2 1b to output the following sequences of tokens!

This part will not be graded for accuracy, so don't worry about being able to elicit these responses exactly. The goal here is for you to get a feel for the behavior of this pretrained model. Use the above cell to test some prompts.

**Target sequences:**

1) "=25.\<EOS\>"
2) "Hawaii is a great location for a winter break trip with your family. [...] \<EOS\>"
3) "To build a bomb, start by [...]"


Below, please fill out which prompts you wrote that got you closest to the target sequence and a short 2-3 sentence reflection on why you chose it and what you observed.

**prompts used:**
1) "FILL IN"
2) "FILL IN"
3) "FILL IN"

Brief reflection about each one:

1) "FILL IN"
2) "FILL IN"
3) "FILL IN"


# Part 2: Explore and load the GSM8K Dataset

Take a look at the GSM8K dataset. Observe the kinds of questions it asks, the format they are given in, and the columns available.

https://huggingface.co/datasets/openai/gsm8k




In [None]:
from datasets import load_dataset
gsm8k_train = #TODO
gsm8k_test = #TODO

print(len(gsm8k_train))
print(len(gsm8k_test))

# Part 3: Testing model performance on GSM8K

Take a couple of minutes to think about how you might test the performance of an LLM on GSM8K.
 It's harder than it seems!

In [None]:
prefix = "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\n\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:"
problem = "Problem: If I have 3 apples, and in sum they cost $4.5, how much does it cost to buy 2 apples?"
prompt = prefix + problem

inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
outputs = model.generate(
    **inputs,
    max_new_tokens=300,
    num_return_sequences=1,
    do_sample=True,
    top_p=0.8,
    temperature=1.0,
    pad_token_id=tokenizer.eos_token_id,
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

You might have noticed that the GSM8K dataset has a numerical answer following the #### delimiter in the answer body. We will be trying to use this to parse the responses from Llama.

Fill in the methods below.

In [None]:
def extract_answer_after_hashes(s):
    # split on "####", strip spaces, take everything after. Returns string
    ## Fill in with your code ##
    return


In [None]:
def preprocess_gsm8k(example):
  # take a single row from gsm8k and return the following columns: {"question", "response", "numeric_answer"}
  ## Fill in with your code ##
  return {"question": q, "response": r, "numeric_answer": a}

In [None]:
gsm8k_train = gsm8k_train.map(preprocess_gsm8k, remove_columns=gsm8k_train.column_names)
gsm8k_test = gsm8k_test.map(preprocess_gsm8k, remove_columns=gsm8k_test.column_names)
# Evaluating the model on hundreds of questions is not computationally trivial.
# Remember, we are literally generating tens to hundreds of tokens one after another for each question.
# You can define a smaller sample of the test set if you want slightly faster testing times, though I would still suggest at least 200 questions.
gsm8k_test_sample = gsm8k_test.select(range(800))

What if the model emits a sequence that doesn't follow the format we expect, but answered correctly? For example, no #### before emitting the correct answer?

What if the model emits multiple ####s?

What if the true answer is a decimal, but the model emits a fraction?

The methods below are designed to handle these types of cases.
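
To see why normalization matters for the decimal-versus-fraction case, here is a minimal sketch using only the standard library. It compares values as exact rationals. `same_value` is a hypothetical helper for illustration and is separate from the helpers defined in the next cell.

```python
from fractions import Fraction

def same_value(a: str, b: str) -> bool:
    """Compare two numeric strings as exact rationals.

    Fraction accepts both "3/2" and "1.5", so different surface forms of
    the same number compare equal.
    """
    return Fraction(a) == Fraction(b)

print(same_value("3/2", "1.5"))   # a fraction and a decimal with the same value
print(same_value("0.50", "1/2"))  # trailing zeros do not matter
print(same_value("1/3", "0.33"))  # a rounded decimal is not exactly equal
```

The last comparison is one reason an evaluation might add a small tolerance on top of exact comparison.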


In [None]:
import re
from decimal import Decimal, InvalidOperation
from fractions import Fraction

_NUM_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:/[1-9]\d*)?")
_ANS_RE = re.compile(r"####\s*(" + _NUM_RE.pattern + r")\b")

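def _to_decimal(s: str):
    s = s.strip().replace(",", "")
    if "/" in s:
        try:
            # Decimal() cannot be built from a Fraction directly,
            # so divide the numerator by the denominator instead.
            frac = Fraction(s)
            return Decimal(frac.numerator) / Decimal(frac.denominator)
        except (ValueError, ZeroDivisionError):
            pass
    try:
        return Decimal(s)
    except InvalidOperation:
        return None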

def normalize_num(s):
    return _to_decimal(str(s))

# explicitly look for "#### <number>" anywhere in the text.
_ANS_MARK_RE = re.compile(r"####\s*([-+]?\d+(?:\.\d+)?)(?!\S)")

def extract_answer(text: str):
    # 1) Preferred: "#### <number>"
    m = _ANS_RE.search(text)
    if m:
        return _to_decimal(m.group(1))

    # 2) Try inside explicit solution tags if they exist
    m = re.search(r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>", text, flags=re.S|re.I)
    if m:
        nums = _NUM_RE.findall(m.group(1))
        if nums:
            return _to_decimal(nums[-1])

    # 3) Fallback: last number anywhere
    nums = _NUM_RE.findall(text)
    if nums:
        return _to_decimal(nums[-1])

    return None

We can use the helpers to return whether there was a match, given the generated text and the true answer from GSM8K.

In [None]:
def evaluate_em(gen, gold, abs_tol=Decimal("1e-9"), rel_tol=Decimal("1e-9")):
    ## Replace with your code ##
    em_flag = 1 if (diff <= abs_tol) or (diff / denom <= rel_tol) else 0
    return em_flag, prediction, gold 

**Evaluation**

Now comes the fun part. How should we present the GSM8K question to Llama such that it is fair, elicits the formatting we want, and can be kept the same across each time we evaluate?

This is up to you, but 3 key pieces of information I might start out with are

1) role information
2) how should the question be attempted
3) formatting guidelines.

Since we are using a strict rule-based system for numerical evaluation, the formatting guidelines are extra important.

Here is a template you might follow:

"Your role... that involves correctly solving math questions in a...  When you are ready to give your solution, format as follows. \n#### \<NUMERIC_ANSWER\>.\n Now solve the following problem:\n".

Now define the prefix below.

In [None]:
PROMPT_GSM8K = (
            ""
            "When you are ready to give your solution, format as follows. \n#### <NUMERIC_ANSWER>.\n Now solve the following problem:\n"
        )

Finally, we will combine the helpers together into our evaluation loop. No coding work is needed here, but take a read to understand how it works.

One important note is that we limit generations to 300 tokens. Answer the following questions about the token generation limit.

Why might one set a maximum token generation length?

"FILL IN"

What are the downsides?

"FILL IN"

What are the upsides?

"FILL IN"

What is the cost of generating 1 new token, given n previous tokens?

"FILL IN"

What is the cost of generating n tokens in sequence, given k previous tokens?

"FILL IN"

In [None]:
import torch
def evaluate_gsm8k(
    model_name,
    questions,
    gold_answers,
    batch_size=4,
    max_new_tokens=300, # notice we limit the response to 300 tokens.
    prefix=None,
    print_every=5,   # print one detailed sample every N examples
):
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    tok.padding_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    mdl = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="cuda", torch_dtype=torch.bfloat16
    )
    mdl.eval()

    prefix = prefix or ""  # use an empty prefix if none is provided

    outs = []
    records = []
    total_early, total_seen = 0, 0
    correct_responses = []

    with torch.inference_mode():
        for i in range(0, len(questions), batch_size):
            print(f"Evaluation progress: {min(i + batch_size, len(questions))} / {len(questions)}")

            batch_q = questions[i:i + batch_size]
            prompts = [prefix + q for q in batch_q]

            enc = tok(prompts, return_tensors="pt", padding=True, truncation=True)
            enc = {k: v.to(mdl.device, non_blocking=True) for k, v in enc.items()}
            enc_len = enc["input_ids"].shape[1]

            # Notice we are doing greedy decoding here (do_sample=False). This is typical in an evaluation
            # setting, where we don't want any source of randomness to affect the accuracy.
            # Sampling parameters such as temperature and top_p are ignored when do_sample=False, so we omit them.
            gen = mdl.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                use_cache=True,
                pad_token_id=tok.eos_token_id,
                eos_token_id=tok.eos_token_id,
                return_dict_in_generate=False,
                output_scores=False,
            )

            new_tokens = gen[:, enc_len:]
            texts = tok.batch_decode(new_tokens, skip_special_tokens=True)

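            for j, (q, t, gold) in enumerate(zip(batch_q, texts, gold_answers[i:i + batch_size])):
                idx = i + j  # global index of this example in `questions`
                # count generated tokens, excluding EOS and the EOS padding that follows it in a batch
                gen_len = int((new_tokens[j] != tok.eos_token_id).sum())
                early = gen_len < max_new_tokens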
                hit_marker = ("####" in t) or ("<|end_of_solution|>" in t)

                cleaned = t.rstrip()
                outs.append(cleaned)

                em_flag, pred_num, gold_num = evaluate_em(t, gold)

                if em_flag == 1:
                    correct_responses.append({
                        "idx": idx,
                        "question": q,
                        "generation": t,
                        "pred": pred_num,
                        "gold": gold_num
                    })

                rec = {
                    "idx": idx,
                    "question": q,
                    "generation": t,     # raw generation (with markers if any)
                    "cleaned": cleaned,  # generation with trailing whitespace stripped
                    "pred_num": str(pred_num) if pred_num is not None else None,
                    "gold": gold,
                    "gold_num": str(gold_num) if gold_num is not None else None,
                    "em": em_flag,
                    "gen_len": gen_len,
                    "max_new_tokens": int(max_new_tokens),
                    "early_stop": early,
                    "hit_marker": hit_marker,
                }
                records.append(rec)

                total_seen += 1
                total_early += int(early)

                if (i + j) % max(1, print_every) == 0:
                    print(f"[{rec['idx']}] early={rec['early_stop']} len={rec['gen_len']} EM={rec['em']}")
                    print(f"Q: {q}")
                    print(f"GEN: {t}")
                    print(f"PRED={rec['pred_num']} | GOLD={rec['gold_num']}")
                    print("-" * 80)

            # free per-batch tensors
            del gen, new_tokens, texts, enc

    # Aggregate stats
    early_rate = total_early / max(1, total_seen)
    gen_lengths = [r["gen_len"] for r in records]
    mean_len = sum(gen_lengths) / max(1, len(gen_lengths))
    median_len = sorted(gen_lengths)[len(gen_lengths) // 2] if gen_lengths else 0

    accuracy = sum(r["em"] for r in records) / max(1, len(records))

    print("\n=== Evaluation summary ===")
    print(f"Accuracy (EM): {accuracy:.4f}")
    print(f"Early-stop rate: {early_rate*100:.1f}%")
    print(f"Gen length: mean={mean_len:.1f}, median={median_len}, cap={max_new_tokens}")
    print("==========================\n")

    if correct_responses:
      print(f"\n=== Correct predictions ({len(correct_responses)}) ===")
      for s in correct_responses:
          print(f"[{s['idx']}] PRED={s['pred']} | GOLD={s['gold']}")
          print("Q:", s["question"])
          print("GEN:", s["generation"])
          print("-" * 80)
    return accuracy, outs, records

**Run the eval!**

In [None]:
accuracy, outs, records = evaluate_gsm8k("meta-llama/Llama-3.2-1B", gsm8k_test_sample["question"], gsm8k_test_sample["numeric_answer"], batch_size=16, prefix=PROMPT_GSM8K)

print(f"Accuracy: {accuracy}")

Please record your accuracy here:

\<ACCURACY\>

And 5 examples here:

\<SAMPLES\>

What do you notice about the behavior of the model when faced with these questions?

\<3-5 sentence reflection\>


# Part 4: Supervised Fine Tuning

SFT is the most basic form of post-training. Apart from one detail, it is exactly the same as the pretraining step.

As we know, pretraining minimizes the average negative log-likelihood of each true token given all the previous ones (see beginning of assignment if you're confused).

In SFT, the loss is still a negative log-likelihood, but over condition-response pairs instead of arbitrary documents. Precisely, it is:
$$
\mathcal{L}_{\text{SFT}}(\theta)
= -\mathbb{E}_{(x, y_{1:T}) \sim P_{\text{SFT}}}
\sum_{t=1}^{T} \log(P_\theta(y_t \mid x, y_{<t}))
$$

The goal here is to modify an existing conditional distribution $P_\theta(y|x)$. <br><br>

Here's an intuitive example.

Let's say we want our model to be as educational as possible.

**Pretrained model**

  -  $x$ = "You are an assistant designed to be as educational as possible. What is 3+4? "

  - Potential response $y$ = "7. What is 9+2? 11. What is 1+1? 2. These are all basic arithmetic questions. For more questions like these, visit my blog. \<EOS\>"

**Fine Tuned model**

  -  $x$ = "You are an assistant designed to be as educational as possible. What is 3+4?"

  -  Potential response $y'$ = "3+4=7 because addition means combining quantities, and counting 3 forward from 4 lands you on 7. \<EOS\>"

<br>
Each training example is usually something like

$data_i$ = \<instruction\>, \<target\>

During training, we feed the entire sequence to the model so it sees both the instruction and the target.

But when computing the loss, we only count the tokens in the answer, not the tokens in the prompt, since we aren't asking the model to learn how to predict the instruction.

That's implemented using a mask: a binary vector (same length as the sequence) with 1s for target tokens and 0s for everything else.
During the loss computation, each token's log-probability is multiplied by its mask value, so gradients only flow through the desired answer part.
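
The masking described above can be sketched with plain Python lists. The probabilities below are invented stand-ins for the model's probability of each true token, and `masked_nll` is a hypothetical helper, not the collator you will write later.

```python
import math

def masked_nll(token_probs, mask):
    """Average negative log-likelihood over positions where mask == 1."""
    assert len(token_probs) == len(mask)
    total = sum(-math.log(p) * m for p, m in zip(token_probs, mask))
    count = sum(mask)
    return total / count if count else 0.0

# Hypothetical probabilities the model assigns to each true token.
# The first three positions are the instruction, the last two the target.
token_probs = [0.9, 0.2, 0.7, 0.5, 0.25]
mask        = [0,   0,   0,   1,   1]

loss = masked_nll(token_probs, mask)
print(f"masked loss = {loss:.4f}")

# Changing the instruction-token probabilities leaves the loss unchanged.
same = masked_nll([0.1, 0.1, 0.1, 0.5, 0.25], mask)
print(f"same loss   = {same:.4f}")
```

In PyTorch-based libraries the same effect is commonly achieved by setting the label of every masked position to `-100`, which the cross-entropy loss ignores by default.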



We will be using the GSM8K dataset to fine tune Llama 3.2 1b.

Formally, we are trying to increase $P(\textbf{gsm8k right answer} | \textbf{the prefix you defined}, \textbf{gsm8k question})$


Let's first pretokenize all our training data. This avoids re-tokenizing the same text on every training step, which reduces preprocessing overhead, and it ensures that all examples share a consistent tokenization scheme. It also lets us inspect and cache tokenized sequences in advance, which is useful for debugging.

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
#TODO START: Define tokenizer with right padding


#TODO END#

if tok.pad_token is None:
    tok.pad_token = tok.eos_token
tok.truncation_side = "left"

MAX_LEN = 400 # I chose this because at ~3 characters a token, this admits a total sequence of 1200 characters.
              #The distribution of character lengths for the summed question, answer columns is mostly under 1200 characters.
BUDGET = MAX_LEN - 1

#TODO START: Tokenize the system prompt/prefix
SYS_IDS =
#TODO END#

def tokenize_batch(batch, include_answer=True):
    qs = [q.rstrip() for q in batch["question"]]

    #TODO START: Tokenize qs without adding special tokens or padding
    enc_q =
    #TODO END#
    has_response = "response" in batch and include_answer

    if has_response:
        ans = [a.rstrip() for a in batch["response"]]
        #TODO START: Tokenize ans without adding special tokens or padding
        enc_a =
        #END TODO#
    else:
        enc_a = {"input_ids": [[] for _ in qs]}

    gold_answers = [(n or "").rstrip() for n in batch.get("numeric_answer", [""] * len(qs))]
    input_ids_list, prompt_len_list, kept_gold = [], [], []

    for i, (q_ids, a_ids) in enumerate(zip(enc_q["input_ids"], enc_a["input_ids"])):
        #TODO START: Define ids
        ids =
        #TODO END#
        if len(ids) > MAX_LEN:
            #TODO START: define behavior if len(ids) > MAX_LEN

            #TODO END#
        input_ids_list.append(ids)
        prompt_len_list.append(len(SYS_IDS) + len(q_ids))
        kept_gold.append(gold_answers[i])

    return {"input_ids": input_ids_list, "prompt_len": prompt_len_list, "gold_answer": kept_gold}


train_tok = gsm8k_train.map(
    tokenize_batch,
    batched=True,
    batch_size=1024,
    num_proc=2,
    remove_columns=gsm8k_train.column_names,
    writer_batch_size=1024,
    desc="Tokenizing train set",
    fn_kwargs={"include_answer": True}
)

val_tok = gsm8k_test.map(
    tokenize_batch,
    batched=True,
    batch_size=1024,
    num_proc=2,
    remove_columns=gsm8k_test.column_names,
    writer_batch_size=1024,
    desc="Tokenizing val set",
    fn_kwargs={"include_answer": True}
)

Additionally, we do have to make sure Llama 3.2 1b can handle any sequence we're interested in. Let's check the maximum context length for Llama 3.2 1b, and compare against the token lengths in the tokenized examples.

In [None]:
from transformers import AutoModelForCausalLM
mdl = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", device_map="auto", torch_dtype="auto")
print("max_position_embeddings:", getattr(mdl.config, "max_position_embeddings", None))

In [None]:
count = 0
total = 0
for i in range(0,1000,10):
  if (len(train_tok[i]['input_ids']) >= 400):
    count+=1
  total +=1
print(count/total)

The PromptMaskedCollator is responsible for taking a list of examples (each containing tokenized input IDs, attention masks, and the length of the prompt) and turning them into a single batch tensor that the model can train on. Importantly, the collator is responsible for masking the log probabilities of the prompt tokens. Fill in the masking logic.

In [None]:
class PromptMaskedCollator:
    def __init__(self, tokenizer, pad_to_multiple_of=8):
        self.tok = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, features):
        prompt_len = torch.tensor([f["prompt_len"] for f in features], dtype=torch.long)

        feats_wo_plen = [{k: v for k, v in f.items() if k != "prompt_len"} for f in features]

        batch = self.tok.pad(
            feats_wo_plen,
            padding=True,
            return_tensors="pt",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )

        input_ids = batch["input_ids"]
        attn = batch["attention_mask"]

        T = input_ids.size(1)
        ar = torch.arange(T, device=input_ids.device).unsqueeze(0)
        plen = prompt_len.unsqueeze(1).to(device=input_ids.device)

        # TODO:
        #labels =
        #labels[] =
        #labels[] =
        #END TODO#
        batch["labels"] = labels
        return batch


collator = PromptMaskedCollator(tok)

In order to do SFT fast, we will fine-tune our model using a method called LoRA, which has been covered in class. Implementing it isn't trivial in raw PyTorch, so instead we'll be using a library called [peft](https://huggingface.co/docs/peft/en/index). Take a look at the LoRA documentation and fill in the code below.
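
To get a feel for why LoRA is fast, the sketch below counts trainable parameters for a single linear layer under full fine-tuning versus a rank-`r` LoRA update. The layer sizes are illustrative assumptions, and `lora_param_counts` is a hypothetical helper that does not use the peft API.

```python
def lora_param_counts(d_in: int, d_out: int, r: int):
    """Return (full, lora) trainable-parameter counts for one linear layer.

    Full fine-tuning updates the whole d_out x d_in weight matrix.
    LoRA freezes it and trains two low-rank factors instead:
    A with shape (r, d_in) and B with shape (d_out, r).
    """
    full = d_in * d_out
    lora = r * d_in + d_out * r
    return full, lora

# Illustrative sizes only: a square 2048 x 2048 projection with rank 8.
full, lora = lora_param_counts(d_in=2048, d_out=2048, r=8)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.4%}")
```

Because the count grows linearly in `r`, doubling the rank doubles the adapter size while the frozen base weights stay untouched.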


In [None]:
import torch
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)

model.config.use_cache = False

# TODO START: Define Lora Config and define the model using the config. I would suggest starting off with r=8, lora_alpha=16, lora_dropout=0.05
# lora_config =
# model =
# TODO END#
model.print_trainable_parameters()

Now we initialize the Trainer and TrainingArguments for training. Please read through the TrainingArguments, and for each one write 1-2 sentences describing its functionality.

Argument functionalities [1-21]:
\<TODO\>
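
As a rough sanity check on how a few of the arguments in the next cell interact, the sketch below estimates the effective batch size and the number of optimizer steps. It assumes a single GPU and takes the dataset size as a parameter. 7,473 is the size of the GSM8K `main` training split. The exact step count reported by `Trainer` may differ slightly depending on how the last partial batch is handled, and `training_schedule` is a hypothetical helper.

```python
import math

def training_schedule(n_examples, per_device_batch, grad_accum, epochs,
                      warmup_ratio, n_devices=1):
    """Estimate effective batch size, optimizer steps, and warmup steps."""
    effective_batch = per_device_batch * grad_accum * n_devices
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, total_steps, warmup_steps

# Values mirror the TrainingArguments cell; n_examples is an assumption.
eff, total, warmup = training_schedule(
    n_examples=7473, per_device_batch=4, grad_accum=4,
    epochs=2, warmup_ratio=0.03,
)
print(f"effective batch = {eff}, total steps ~ {total}, warmup steps ~ {warmup}")
```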

In [None]:
args = TrainingArguments(
    output_dir="./gsm8ksft_1b_lora",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    logging_steps=1,
    eval_strategy="steps",
    eval_steps=25,
    save_steps=250,
    save_total_limit=2,
    bf16=False,
    fp16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    optim="adamw_torch",
    report_to="none",
    remove_unused_columns=False,
    group_by_length=True,
)

#subsample
val_tok_sample = val_tok.shuffle(seed=42).select(range(100))

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=val_tok_sample,
    data_collator=collator,
)

In [None]:
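# Optional: clear GPU memory without restarting the runtime.
# Only run this if you need to free memory; you must re-create the model and trainer afterwards.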
import gc, torch
for name in ("trainer","model","optim","scheduler"):
    if name in globals(): del globals()[name]
gc.collect()
torch.cuda.empty_cache()

In [None]:
trainer.train()
trainer.save_model()
model.push_to_hub("<hf username>/gsm8ksft-1b-lora")
tok.push_to_hub("<hf username>/gsm8ksft-1b-lora")

In [None]:
from peft import AutoPeftModelForCausalLM
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B", use_fast=True)

# load the trained adapter (it knows the base model from its config)
peft_model = AutoPeftModelForCausalLM.from_pretrained(
    "<hf username>/gsm8ksft-1b-lora", torch_dtype="auto", device_map="auto"
)

# merge LoRA weights into the base weights and drop PEFT wrappers
merged = peft_model.merge_and_unload()

# save a standard HF model folder
merged.save_pretrained("./gsm8k_1b_lora_merged", safe_serialization=True)
tok.save_pretrained("./gsm8k_1b_lora_merged")

In [None]:
accuracy, outs, records = evaluate_gsm8k("./gsm8k_1b_lora_merged", gsm8k_test_sample["question"], gsm8k_test_sample["numeric_answer"], batch_size=16, prefix=PROMPT_GSM8K)
print(f"Accuracy: {accuracy}")

Please record your accuracy here:

\<ACCURACY\>

And 5 examples here:

\<SAMPLES\>

Do you notice anything different about the model behavior?

\<3-5 sentence reflection\>

# Part 5: Reinforcement Learning with REINFORCE

In supervised learning, we train a model to minimize a loss function comparing its predictions to ground-truth labels. However, in many problems, especially when the correct output is not uniquely defined or only indirectly measurable (e.g. dialogue helpfulness, game score, or text quality), we only know how good an output is.
This setting leads naturally to reinforcement learning (RL).

In reinforcement learning, there are no fixed “correct” labels. Instead, the model (called an agent) learns by interacting with an environment and receiving rewards that measure how good its actions were.

At each step $t$:
1) The agent observes the state $s_t$ of the environment
2) It samples an action $a_t \sim \pi_\theta(a_t|s_t)$ from its policy.
3) The environment transitions to a new state $s_{t+1}$ and emits a reward $r_t \in \mathbb{R}$

This process continues for **T** steps, and we call the entire process a trajectory
$\tau = (s_1,a_1,r_1, s_2,a_2,r_2, ..., s_T, a_T, r_T)$

The total reward from $\tau$ is called the return:

$R(\tau) = \sum_{t=1}^{T}\gamma^{t-1}r_t$

The discount factor $\gamma \in [0, 1]$ models the intuition that receiving a reward earlier is of more utility than receiving a reward later.
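
A quick numeric sketch of the return formula: the same single reward is worth less when it arrives later. The reward sequences and the value of `gamma` are made up for illustration, and `discounted_return` is a hypothetical helper.

```python
def discounted_return(rewards, gamma):
    """R(tau) = sum over t of gamma^(t-1) * r_t, with rewards[0] as r_1."""
    return sum((gamma ** t) * r for t, r in enumerate(rewards))

# The same single reward of 1, received at the first step vs. the last step.
early = discounted_return([1, 0, 0, 0], gamma=0.9)
late = discounted_return([0, 0, 0, 1], gamma=0.9)
print(early, late)
```

With `gamma=1.0` the return reduces to the plain sum of rewards, so the timing of a reward no longer matters.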

The objective of RL is to find policy parameters $\theta$ that maximize expected return:

$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]$

To put it simply, $J(\theta)$ is the average performance of the policy $\pi_\theta$, where the randomness comes from **a)** sampling each action $a_t$ and/or **b)** an environment that changes independently.

The key difference from supervised learning is that there we were able to differentiate the loss w.r.t. our model parameters, whereas here we cannot differentiate the reward w.r.t. our model parameters. Our objective is to maximize $J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]$, and $R(\tau)$ comes from the environment. From the perspective of the parameters, $R(\tau)$ is a black box that outputs a scalar signal after we take a sequence of actions.

So we can't take gradients through the reward function,

**BUT** we can take gradients through the probability of sampling trajectories that lead to reward (with the assumption that the reward function stays fixed).

This leads to the key idea behind policy gradient methods like REINFORCE:

$$
\begin{aligned}
J(\theta) &= \mathbb{E}_{\tau \sim \pi_\theta} [R(\tau)] = \sum_{\tau} P(\tau; \theta) \, R(\tau) \\
\nabla_\theta J(\theta) &= \sum_{\tau} \nabla_\theta P(\tau; \theta) \, R(\tau)
\end{aligned}
$$

Using the log-derivative identity $\nabla_\theta P(\tau; \theta) = P(\tau; \theta) \, \nabla_\theta \log P(\tau; \theta)$:

$$
\begin{aligned}
\nabla_\theta J(\theta) &= \sum_{\tau} P(\tau; \theta) \, \nabla_\theta \log P(\tau; \theta) \, R(\tau) \\
&= \mathbb{E}_{\tau \sim \pi_\theta} \big[ R(\tau) \, \nabla_\theta \log P(\tau; \theta) \big]
\end{aligned}
$$

The trajectory probability factors into policy terms and environment terms, and only the policy terms depend on $\theta$:

$$
\begin{aligned}
P(\tau; \theta) &= p(s_1) \prod_{t=1}^{T} \pi_\theta(a_t | s_t) \, p(s_{t+1} | s_t, a_t) \\
\nabla_\theta \log P(\tau; \theta) &= \sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t | s_t)
\end{aligned}
$$

Putting it together:

$$
\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[R(\tau)\sum_{t=1}^{T}\nabla_\theta \log \pi_\theta(a_t | s_t)\right]
$$
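
The final expression can be checked numerically on a toy problem. The sketch below assumes a one-step trajectory with a softmax policy over three actions, with hypothetical logits and rewards. It computes the policy-gradient expectation exactly by enumerating actions and compares it against a finite-difference gradient of $J(\theta)$. Note that the reward is never differentiated, only the log-probabilities.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def expected_return(theta, rewards):
    """J(theta) = sum_a pi_theta(a) * R(a) for a one-step problem."""
    return sum(p * r for p, r in zip(softmax(theta), rewards))

def policy_gradient(theta, rewards):
    """E_{a ~ pi}[ R(a) * grad_theta log pi_theta(a) ], computed exactly.

    For a softmax policy, d/d theta_k log pi(a) = 1[a == k] - pi(k).
    """
    pi = softmax(theta)
    grad = [0.0] * len(theta)
    for a, (p_a, r_a) in enumerate(zip(pi, rewards)):
        for k in range(len(theta)):
            grad_log = (1.0 if a == k else 0.0) - pi[k]
            grad[k] += p_a * r_a * grad_log
    return grad

def finite_difference(theta, rewards, eps=1e-6):
    """Numerical gradient of J, used only as a reference."""
    grad = []
    for k in range(len(theta)):
        up = list(theta); up[k] += eps
        down = list(theta); down[k] -= eps
        diff = expected_return(up, rewards) - expected_return(down, rewards)
        grad.append(diff / (2 * eps))
    return grad

theta = [0.2, -0.1, 0.5]     # hypothetical logits
rewards = [1.0, 0.0, 0.0]    # only the first action is "correct"

pg = policy_gradient(theta, rewards)
fd = finite_difference(theta, rewards)
print([round(g, 6) for g in pg])
print([round(g, 6) for g in fd])
```

The gradient is positive for the rewarded action and negative for the others, so a gradient-ascent step shifts probability toward the action that earned reward.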


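Reframing this back into NLP, a trajectory is the sequence of generated tokens $\tau = (o_1,o_2,...,o_T)$, each action $a_t$ is the token $o_t$, and the state $s_t$ is simply the tokens generated so far, $o_{< t}$ (preceded by the prompt).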

For GSM8K, we only have a "reward" at the end of the trajectory (when we have emitted \<EOS\> or hit the max generation count). The simple assumption in REINFORCE is that each token shares equal responsibility, $r_t = \frac{R(\tau)}{T}$.

Awesome, let's start by writing a reward function that assigns rewards (1 if correct, 0 if incorrect) to trajectories. (This is very similar to your evaluate_em method!)
Additionally, the function should return the average reward per batch.

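When we train, we will subtract a baseline $b$ (the average reward in the batch, as returned by the reward function) from each trajectory's reward: $R_{\tau_i} \leftarrow R_{\tau_i} - b$.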

This doesn't change the direction of the expected gradient, but helps training be more stable. If every batch has both correct and incorrect answers, we want to nudge up the probabilities of correct ones and down the incorrect ones relative to the batch average.
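
The effect of a baseline can be checked exactly on a small example. The sketch below is an illustration under simplifying assumptions, not the training code: it uses a two-action softmax policy with hypothetical rewards, and a constant baseline equal to the true expected reward (the assignment uses the batch average instead). It computes the mean and variance of the single-sample gradient estimator for the first logit, with and without the baseline.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grad_estimator_stats(theta, rewards, baseline):
    """Exact mean and variance of (R(a) - baseline) * d/d theta_0 log pi(a), a ~ pi."""
    probs = softmax(theta)
    mean = 0.0
    second_moment = 0.0
    for a, (p, r) in enumerate(zip(probs, rewards)):
        # For a softmax policy: d/d theta_0 log pi(a) = 1[a == 0] - pi(0)
        dlogp = (1.0 if a == 0 else 0.0) - probs[0]
        g = (r - baseline) * dlogp
        mean += p * g
        second_moment += p * g * g
    return mean, second_moment - mean ** 2

theta = [0.0, 0.0]      # hypothetical logits: a uniform policy
rewards = [1.0, 0.0]    # action 0 is "correct", action 1 is "incorrect"
b = sum(p * r for p, r in zip(softmax(theta), rewards))  # expected reward

mean_raw, var_raw = grad_estimator_stats(theta, rewards, baseline=0.0)
mean_b, var_b = grad_estimator_stats(theta, rewards, baseline=b)
print("no baseline:  ", mean_raw, var_raw)
print("with baseline:", mean_b, var_b)
```

In this toy case the expected gradient is identical in both settings, while the variance of the estimator is smaller with the baseline. That variance reduction is the stability benefit described above.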

In [None]:
def reward_numeric(pred_texts, gold_answers, abs_tol=Decimal("1e-9"), rel_tol=Decimal("1e-9")):
    """
    Compute per-example numeric rewards for REINFORCE on GSM8K-style outputs.
    Returns a list of integer rewards (one per GSM8K question) and the average reward in the minibatch.
    """
    n = min(len(pred_texts), len(gold_answers))
    if n == 0:
        return [], 0.0

    #TODO START: initialize rewards#
    #TODO END#
    for gen, gold in zip(pred_texts[:n], gold_answers[:n]):
        pred = extract_answer(gen)
        gold_norm = normalize_num(gold)

        if pred is None or gold_norm is None:
            #TODO START:
            #TODO END#

        if pred == gold_norm:
            #TODO START:
            #TODO END#

        diff = abs(pred - gold_norm)
        denom = max(Decimal(1), abs(gold_norm))
        if diff <= abs_tol or diff / denom <= rel_tol:
            #TODO START:
            #TODO END#
        else:
            #TODO START:
            #TODO END#

    #TODO START:
    baseline =
    #TODO END#
    return rewards, baseline


In [None]:
class PromptOnlyCollator:
    def __init__(self, tokenizer, pad_to_multiple_of=8):
        self.tok = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of
    def __call__(self, features):
        batch = self.tok.pad(
            {"input_ids": [f["input_ids"] for f in features]},
            padding=True, return_tensors="pt",
            pad_to_multiple_of=self.pad_to_multiple_of
        )
        batch["gold_answer"] = [f["gold_answer"] for f in features]
        return batch

In [None]:
gsm8k_train = load_dataset("openai/gsm8k", 'main', split = "train")
gsm8k_test = load_dataset("openai/gsm8k",'main', split = "test")

gsm8k_train = gsm8k_train.map(preprocess_gsm8k, remove_columns=gsm8k_train.column_names)
gsm8k_test = gsm8k_test.map(preprocess_gsm8k, remove_columns=gsm8k_test.column_names)

gsm8k_test_sample = gsm8k_test.select(range(200))


import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B", use_fast=True)
tok.padding_side = "left"      # left-pad so generation continues right after the prompt
tok.truncation_side = "left"
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

MAX_LEN = 400
BUDGET = MAX_LEN - 1

SYS_IDS = tok(PROMPT_GSM8K, add_special_tokens=False)["input_ids"]

train_tok_rl = gsm8k_train.map(
    tokenize_batch,
    batched=True,
    batch_size=1024,
    num_proc=2,
    remove_columns=gsm8k_train.column_names,
    writer_batch_size=1024,
    desc="Tokenizing train set",
    fn_kwargs={"include_answer": False}
)

val_tok_rl = gsm8k_test.map(
    tokenize_batch,
    batched=True,
    batch_size=1024,
    num_proc=2,
    remove_columns=gsm8k_test.column_names,
    writer_batch_size=1024,
    desc="Tokenizing val set",
    fn_kwargs={"include_answer": False}
)

Here is the class we will define for the RL training. It extends the Hugging Face `Trainer`.
There are **4** TODOs.

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, BitsAndBytesConfig
)
class REINFORCETrainer(Trainer):
    def __init__(self, *args, gen_kwargs=None, ref_model=None, kl_beta=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.gen_kwargs = gen_kwargs or dict(
            max_new_tokens=300,
            do_sample=True, top_p=0.9, temperature=0.7,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            return_dict_in_generate=True, output_scores=False,  # scores not needed now
        )
        self.ref_model = ref_model
        self.kl_beta = kl_beta

    @torch.no_grad()
    def _decode(self, seqs):
        return self.tokenizer.batch_decode(seqs, skip_special_tokens=True)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Inputs from collator: input_ids (prompt only), attention_mask, gold_answer (list[str])
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        B = input_ids.size(0)
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id

        # On-policy sampling
        with torch.no_grad():
            if not hasattr(self, "_checked_pad"):
                right_zero = (attention_mask[:, -1] == 0).any().item()
                assert not right_zero, "Right padding slipped in"
                self._checked_pad = True

            gen = model.generate(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 **self.gen_kwargs)
            full_ids = gen.sequences                         # [B, T_full]
            full_attn = (full_ids != pad_id).long()          # [B, T_full]
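            # With left padding, every prompt occupies the first input_ids.size(1) positions
            # of full_ids, so generation starts at the same index for each sample.
            prompt_lens = torch.full((B,), input_ids.size(1), dtype=torch.long, device=full_ids.device)  # [B]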

        T_full = full_ids.size(1)
        # forward pass (with grad) over the whole sampled sequence
        out = model(input_ids=full_ids[:, :-1], attention_mask=full_attn[:, :-1])
        logits = out.logits.float()
        # Sanitize the fp32 logits *before* log_softmax.
        # This fixes the Inf from the unstable LoRA weights.
        logits = torch.nan_to_num(
            logits,
            nan=0.0,
            posinf=torch.finfo(torch.float32).max,
            neginf=torch.finfo(torch.float32).min
        )
        logprobs = torch.log_softmax(logits, dim=-1)


        ar = torch.arange(T_full, device=full_ids.device).unsqueeze(0).expand(B, -1)  # [B,T]

        # ===== TODO =====
        # TASK 1: Build a mask of which positions to score for each sample.
        # Requirements:
        # - score only positions >= prompt_len
        # - stop at first EOS (included)
        # - ignore padding
        # Output shape: valid_mask: [B, T_full]
        #
        # Hints:
        # ar (computed above) already holds the position indices, shape [B, T_full]
        # gen_region = ar >= prompt_lens[:, None]
        # then cut off after first eos
        #
        # TODO: implement valid_mask
        # valid_mask = ...

        # TASK 2: For each sample i, gather log p(sampled_token_i_t) at valid positions.
        # Remember: logits[:, t-1, :] predicts token at t (shift by one).

        # TASK 3: Sum the logprobs for each sample i over the valid positions only.

        # TASK 4: Compute the advantage and normalize it and compute the REINFORCE loss.

        logprob_sums = []
        for i in range(B):
            # TODO: get valid positions for this sample
            # pos = ...
            # TODO: skip position 0 (no prediction for t=0)
            # pos = pos[pos > 0]
            if #TODO fill in condition:
                safe_zero = (model.get_input_embeddings().weight[0, 0] * 0.0).to(dtype=logits.dtype)
                logprob_sums.append(safe_zero)
                continue
            # gather next-token logprobs for the actually sampled tokens
            tok_ids = full_ids[i, pos]                                                # [Ti]
            step_logits = logprobs[i, pos - 1, :]                                     # [Ti, V]
            step_logps = step_logits.gather(-1, tok_ids.view(-1, 1)).squeeze(-1)      # [Ti]
            step_logps = torch.where(
                torch.isfinite(step_logps),
                step_logps,
                torch.zeros_like(step_logps)
            )
            step_logps = step_logps.clamp(min=-20.0, max=0.0)

            Ti = pos.numel()

            # TODO: add logprobs to logprob_sums

        logprob_sums = torch.stack(logprob_sums, dim=0)                                # [B]

        assert logprob_sums.requires_grad, "logprob_sums lost grad (masking/shift produced all-empty rows)"

        # Rewards + batch baseline
        with torch.no_grad():
            texts = self._decode(full_ids)
            rewards_list, b = reward_numeric(texts, inputs["gold_answer"])
            rewards  = torch.tensor(rewards_list, dtype=logprob_sums.dtype, device=logprob_sums.device)  # [B]
            baseline = torch.tensor(b, dtype=logprob_sums.dtype, device=logprob_sums.device)

        em_batch = rewards.float().mean().item()
        self.log({"em_batch": em_batch})

        # TODO: Compute the advantage and normalize it.
        # 1. advantages = rewards - baseline
        # 2. optional: normalize to unit variance
        # 3. use advantage to weight logprob_sums in the REINFORCE loss

        # advantages = ...
        # std = advantages.std()
        # if torch.isfinite(std) and std > 0:
        #     advantages = ...
        # loss = ...

        # ===== TODO END =====
        # KL regularization with a frozen ref model could be added here.
        kl_loss = torch.tensor(0.0, device=loss.device)
        if self.ref_model is not None and self.kl_beta > 0:
            with torch.no_grad():
                # Third forward pass (no grad) with the REFERENCE model
                ref_out = self.ref_model(input_ids=full_ids[:, :-1], attention_mask=full_attn[:, :-1])
                ref_logprobs = torch.log_softmax(ref_out.logits.float(), dim=-1)

            # Mask for generated tokens (shifted)
            mask = valid_mask[:, 1:].contiguous()  # valid_mask is built in TASK 1
            mask_float = mask.float()

            target_ids = full_ids[:, 1:].unsqueeze(-1)  # [B, T_full-1, 1]

            # policy logp
            policy_logp = torch.gather(logprobs, -1, target_ids).squeeze(-1)  # [B, T-1]
            # ref logp
            ref_logp = torch.gather(ref_logprobs, -1, target_ids).squeeze(-1)  # [B, T-1]

            # clamp both to avoid -inf / inf
            policy_logp = policy_logp.clamp(min=-20.0, max=0.0)
            ref_logp    = ref_logp.clamp(min=-20.0, max=0.0)

            # Per-token KL estimate from the sampled tokens: log pi_theta - log pi_ref
            kl_per_token = (policy_logp - ref_logp) * mask_float

            # sanitize per-token KL
            kl_per_token = torch.where(torch.isfinite(kl_per_token),
                                      kl_per_token,
                                      torch.zeros_like(kl_per_token))


            # Average KL *per sequence*
            kl_per_seq_sum = kl_per_token.sum(dim=1)
            valid_tokens_per_seq = mask_float.sum(dim=1).clamp(min=1)
            kl_per_seq = kl_per_seq_sum / valid_tokens_per_seq

            # Average KL *over the batch*
            kl_loss = kl_per_seq.mean()

        # Store component values for logging before combining
        policy_loss_item = loss.item()
        kl_loss_item = kl_loss.item()
        if not torch.isfinite(kl_loss):
            kl_loss = torch.tensor(0.0, device=loss.device)
        # Combine policy loss and KL regularization
        loss = loss + self.kl_beta * kl_loss

        self.log({
            "em_batch": em_batch,
            "policy_loss": policy_loss_item,
            "kl_loss": kl_loss_item,
            "loss": loss.item()
        })

        outputs = {
            "loss": loss,
            "policy_loss": policy_loss_item,
            "kl_loss": kl_loss_item
        }
        return (loss, outputs) if return_outputs else loss
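
The shift-by-one indexing in `compute_loss` (the output at position `t-1` predicts the token at position `t`) is easy to get wrong, so here is a small pure-Python sketch on a single toy sequence. It is not the vectorized implementation the TODOs ask for, and every value in it is hypothetical: the token ids, the made-up `logits`, and the choice of distinct `EOS` and `PAD` ids.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over a list of floats."""
    m = max(row)
    lse = m + math.log(sum(math.exp(z - m) for z in row))
    return [z - lse for z in row]

# Toy setup: vocabulary of 4 ids with EOS = 3 and PAD = 0.
EOS, PAD = 3, 0
full_ids = [0, 2, 1, 2, 3, 0]   # [pad, prompt, prompt, generated, generated EOS, pad]
prompt_width = 3                # the left-padded prompt occupies positions 0..2

# logits[t] is the model output after reading full_ids[: t + 1],
# so it is the prediction for full_ids[t + 1].
logits = [[0.1 * (t + v) for v in range(4)] for t in range(len(full_ids) - 1)]
logprobs = [log_softmax(row) for row in logits]

# Score generated positions only, stopping at the first EOS (inclusive).
scored_positions = []
for t in range(prompt_width, len(full_ids)):
    scored_positions.append(t)
    if full_ids[t] == EOS:
        break

# The token at position t is scored with the prediction made at position t - 1.
seq_logprob = sum(logprobs[t - 1][full_ids[t]] for t in scored_positions)
print(scored_positions, seq_logprob)
```

Here only the two generated tokens are scored. The prompt, the left padding, and the trailing padding after EOS contribute nothing to the sequence log-probability.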

Time to train! This step will take much longer than the SFT (and you very well may not get much improvement).

Think about how much work is required to get a signal (compared to SFT) and how dense that signal is (compared to SFT).

In [None]:
import torch
import os, logging, warnings
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.enable_progress_bar()

tok = AutoTokenizer.from_pretrained("./gsm8k_1b_lora_merged", use_fast=True)
tok.pad_token = tok.eos_token
tok.padding_side = "left"
rl_collator = PromptOnlyCollator(tok)

model = AutoModelForCausalLM.from_pretrained(
    "./gsm8k_1b_lora_merged",
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
model.generation_config.return_dict_in_generate = True
model.generation_config.output_scores = False
model.config.use_cache = False

# LoRA config
lora_config = LoraConfig(
    r=16,                      # 8–16 common; raise if updates feel too weak
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj","k_proj","v_proj","o_proj",
        "gate_proj","up_proj","down_proj",
    ],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

ref_model = AutoModelForCausalLM.from_pretrained(
    "./gsm8k_1b_lora_merged",
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
ref_model.eval()
ref_model.requires_grad_(False)


args = TrainingArguments(
    output_dir="./gsm8k_rl_lora",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=16,
    max_grad_norm=1.0,
    learning_rate=1e-6,
    weight_decay=0.0,
    num_train_epochs=0.1,
    logging_steps=1,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=500,
    save_total_limit=2,
    optim="adamw_torch",
    fp16=True,
    gradient_checkpointing=False,
    report_to="none",
    remove_unused_columns=False,
)


trainer = REINFORCETrainer(
    model=model,
    args=args,
    train_dataset=train_tok_rl,
    eval_dataset=val_tok_rl.select(range(64)),
    data_collator=rl_collator,
    processing_class=tok,
    gen_kwargs=dict(max_new_tokens=300, min_new_tokens=1, do_sample=True, top_p=0.9),
    ref_model=ref_model,
    kl_beta=0.1,
)

In [None]:
trainer.train()
trainer.save_model()
model.push_to_hub("<hf username>/gsm8krl-1b-lora1")
tok.push_to_hub("<hf username>/gsm8krl-1b-lora1")

In [None]:
from peft import AutoPeftModelForCausalLM
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B", use_fast=True)

# load the trained adapter (it knows the base model from its config)
peft_model = AutoPeftModelForCausalLM.from_pretrained(
    "<hf username>/gsm8krl-1b-lora1", torch_dtype="auto", device_map="auto"
)

# merge LoRA weights into the base weights and drop PEFT wrappers
merged = peft_model.merge_and_unload()

# save a standard HF model folder
merged.save_pretrained("./gsm8krl_1b_lora1_merged", safe_serialization=True)
tok.save_pretrained("./gsm8krl_1b_lora1_merged")

In [None]:
accuracy, outs, generated_answers = evaluate_gsm8k("./gsm8krl_1b_lora1_merged", gsm8k_test_sample["question"], gsm8k_test_sample["numeric_answer"], 16)
print(f"Accuracy: {accuracy}")

Please record your accuracy here:

\<ACCURACY\>

And 5 examples here:

\<SAMPLES\>

Do you notice anything different about the model behavior?

\<3-5 sentence reflection\>

In [None]:
#written by Daniel Zhang and Joey Huang