Vivek committed on
Commit 3adb47c · 1 Parent(s): a3ead9a

uploaded prediction

Files changed (3)
  1. src/prediction.py +57 -0
  2. src/test.py +59 -0
  3. src/train.py +241 -0
src/prediction.py ADDED
@@ -0,0 +1,57 @@
+ import jax
+ import jax.numpy as jnp
+
+ import flax
+ import flax.linen as nn
+ from flax.core.frozen_dict import FrozenDict, unfreeze
+
+ from typing import Any, Optional, Tuple
+
+ from transformers import GPT2Config
+
+ import transformers
+ from transformers import GPT2Tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')
+ from datasets import load_dataset, load_metric
+
+ from datasets import Dataset
+
+ from model_file import FlaxGPT2ForMultipleChoice
+
+ import logging
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+
+ # Path to the CSV file with the examples to run prediction on.
+ run_dataset = Dataset.from_csv('......')
+
+ def preprocess(example):
+     # Concatenate context and question, then pair the result with each of the four answer choices.
+     example['context&question'] = example['context'] + example['question']
+     example['first_sentence'] = [example['context&question']] * 4
+     example['second_sentence'] = example['answer0'], example['answer1'], example['answer2'], example['answer3']
+     return example
+
+ run_dataset = run_dataset.map(preprocess)
+
+ def tokenize(examples):
+     a = tokenizer(examples['first_sentence'], examples['second_sentence'], padding='max_length', truncation=True, max_length=256, return_tensors='jax')
+     a['labels'] = examples['label']
+     return a
+
+ run_dataset = run_dataset.map(tokenize)
+
+ remov_col = ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels', 'context&question', 'first_sentence', 'second_sentence']
+
+ run_dataset = run_dataset.remove_columns(remov_col)
+
+ model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos", input_shape=(1, 4, 1))
+
+ input_id = jnp.array(run_dataset['input_ids'])
+ att_mask = jnp.array(run_dataset['attention_mask'])
+
+ outputs = model(input_id, att_mask)
+
+ final_output = jnp.argmax(outputs, axis=-1)
+
+ logger.info(f"the prediction of the dataset : {final_output}")
src/test.py ADDED
@@ -0,0 +1,59 @@
+ import jax
+ import jax.numpy as jnp
+
+ import flax
+ import flax.linen as nn
+ from flax.core.frozen_dict import FrozenDict, unfreeze
+
+ from typing import Any, Optional, Tuple
+
+ from transformers import GPT2Config
+
+ import transformers
+ from transformers import GPT2Tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')
+ from datasets import load_dataset, load_metric
+
+ from model_file import FlaxGPT2ForMultipleChoice
+
+ import logging
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+
+ dataset = load_dataset('cosmos_qa')
+
+ len_test_dataset = 6963
+
+ test_dataset = dataset['test'].select(range(len_test_dataset))
+
+ def preprocess(example):
+     # Concatenate context and question, then pair the result with each of the four answer choices.
+     example['context&question'] = example['context'] + example['question']
+     example['first_sentence'] = [example['context&question']] * 4
+     example['second_sentence'] = example['answer0'], example['answer1'], example['answer2'], example['answer3']
+     return example
+
+ test_dataset = test_dataset.map(preprocess)
+
+ def tokenize(examples):
+     a = tokenizer(examples['first_sentence'], examples['second_sentence'], padding='max_length', truncation=True, max_length=256, return_tensors='jax')
+     a['labels'] = examples['label']
+     return a
+
+ test_dataset = test_dataset.map(tokenize)
+
+ remov_col = ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels', 'context&question', 'first_sentence', 'second_sentence']
+
+ test_dataset = test_dataset.remove_columns(remov_col)
+
+ model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos", input_shape=(1, 4, 1))
+
+ input_id = jnp.array(test_dataset['input_ids'])
+ att_mask = jnp.array(test_dataset['attention_mask'])
+
+ outputs = model(input_id, att_mask)
+
+ final_output = jnp.argmax(outputs, axis=-1)
+
+ logger.info(f"the prediction of the test dataset : {final_output[:30]}")
src/train.py ADDED
@@ -0,0 +1,241 @@
+ import jax
+ print(jax.local_device_count())
+ import jax.numpy as jnp
+
+ import flax
+ import flax.linen as nn
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+ from flax.training import train_state
+ from flax.metrics.tensorboard import SummaryWriter
+ from flax.training import checkpoints
+
+ import logging
+ import optax
+ import math
+ from tqdm import tqdm
+
+ from pathlib import Path
+ from typing import Callable
+ from itertools import chain
+ from flax.metrics import tensorboard
+
+ from datasets import load_dataset, load_metric
+ from transformers import GPT2Config, GPT2Tokenizer
+
+ from model_file import FlaxGPT2ForMultipleChoice
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+
+ def main():
+
+     tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|endoftext|>')
+
+     dataset = load_dataset('cosmos_qa')
+
+     def preprocess(example):
+         # Concatenate context and question, then pair the result with each of the four answer choices.
+         example['context&question'] = example['context'] + example['question']
+         example['first_sentence'] = [example['context&question']] * 4
+         example['second_sentence'] = example['answer0'], example['answer1'], example['answer2'], example['answer3']
+         return example
+
+     train_dataset = dataset['train'].map(preprocess)
+     validation_dataset = dataset['validation'].map(preprocess)
+     test_dataset = dataset['test'].map(preprocess)
+
+     # Remove after experiment
+     len_train_dataset = 25262
+     len_validation_dataset = 2985
+     len_test_dataset = 6963
+
+     train_dataset = train_dataset.select(range(len_train_dataset))
+     test_dataset = test_dataset.select(range(len_test_dataset))
+     validation_dataset = validation_dataset.select(range(len_validation_dataset))
+
+     # remove_cols = train_dataset.column_names
+
+     def tokenize(examples):
+         a = tokenizer(examples['first_sentence'], examples['second_sentence'], padding='max_length', truncation=True, max_length=256, return_tensors='jax')
+         a['labels'] = examples['label']
+         return a
+
+     train_dataset = train_dataset.map(tokenize)
+     validation_dataset = validation_dataset.map(tokenize)
+     test_dataset = test_dataset.map(tokenize)
+
+     remov_col = ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels', 'context&question', 'first_sentence', 'second_sentence']
+
+     train_dataset = train_dataset.remove_columns(remov_col)
+     validation_dataset = validation_dataset.remove_columns(remov_col)
+     test_dataset = test_dataset.remove_columns(remov_col)
+
+     per_device_batch_size = 4
+     seed = 0
+     num_train_epochs = 3
+     learning_rate = 2e-5
+
+     total_batch_size = per_device_batch_size * jax.local_device_count()
+     print('The overall batch size (both for training and eval) is', total_batch_size)
+     num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
+     num_validation_steps = len(validation_dataset) // total_batch_size * num_train_epochs
+
+     # Learning rate decays linearly from the initial value to 0 over training.
+     learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
+
+     class TrainState(train_state.TrainState):
+         logits_function: Callable = flax.struct.field(pytree_node=False)
+         loss_function: Callable = flax.struct.field(pytree_node=False)
+
+     def adamw(weight_decay):
+         return optax.adamw(learning_rate=learning_rate_function, b1=0.9, b2=0.99, eps=1e-6, weight_decay=weight_decay)
+
+     decay_path = lambda p: not any(x in p for x in ['bias', 'LayerNorm.weight'])
+
+     def traverse(function):
+         def mask(data):
+             flat = flax.traverse_util.flatten_dict(data)
+             return flax.traverse_util.unflatten_dict({k: function(k, v) for k, v in flat.items()})
+         return mask
+
+     # Split the parameters into two groups so weight decay is applied selectively.
+     gradient_transformation = optax.chain(
+         optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))),
+         optax.masked(adamw(0.01), mask=traverse(lambda path, _: not decay_path(path))))
+
+     def loss_function(logits, labels):
+         logits = flax.linen.log_softmax(logits)
+         xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=4))
+         return jnp.mean(xentropy)
+
+     def eval_function(logits):
+         return logits.argmax(-1)
+
+     model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2-large', input_shape=(1, 4, 1))
+
+     state = TrainState.create(apply_fn=model.__call__,
+                               params=model.params,
+                               tx=gradient_transformation,
+                               logits_function=eval_function,
+                               loss_function=loss_function)
+
+     def train_step(state, batch, dropout_rng):
+         targets = batch.pop("label")
+         dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+         def loss_function(params):
+             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+             loss = state.loss_function(logits, targets)
+             return loss
+         grad_function = jax.value_and_grad(loss_function)
+         loss, grad = grad_function(state.params)
+         grad = jax.lax.pmean(grad, "batch")
+         new_state = state.apply_gradients(grads=grad)
+         # Added: recompute logits with the updated parameters to track training accuracy.
+         logits = new_state.apply_fn(**batch, params=new_state.params, dropout_rng=dropout_rng, train=True)[0]
+         accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets)
+         metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step), 'accuracy': accuracy}, axis_name="batch")
+         return new_state, metrics, new_dropout_rng
+
+     parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
+
+     def eval_step(state, batch):
+         targets = batch.pop('label')
+         logits = state.apply_fn(**batch, params=state.params, train=False)
+         loss = state.loss_function(logits, targets)
+         predictions = state.logits_function(logits)
+         eval_accuracy = jnp.equal(predictions, targets)
+         # eval_acc = jnp.equal(predictions, targets)
+         metrics = jax.lax.pmean({"loss": loss, 'accuracy': eval_accuracy}, axis_name="batch")
+         # return state.logits_function(logits)  # (8, 4)
+         return targets, predictions, metrics
+
+     parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
+
+     def glue_train_data_loader(rng, dataset, batch_size):
+         # Shuffle the training set each epoch and yield sharded batches for pmap.
+         steps_per_epoch = len_train_dataset // batch_size
+         perms = jax.random.permutation(rng, len(dataset))
+         perms = perms[:steps_per_epoch * batch_size]
+         perms = perms.reshape((steps_per_epoch, batch_size))
+         for perm in perms:
+             batch = dataset[perm]
+             batch = {k: jnp.array(v) for k, v in batch.items()}
+             batch = shard(batch)
+             yield batch
+
+     rng = jax.random.PRNGKey(seed)
+     dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+     def glue_eval_data_loader(dataset, batch_size):
+         for i in range(len_validation_dataset // batch_size):
+             batch = dataset[i * batch_size : (i + 1) * batch_size]
+             batch = {k: jnp.array(v) for k, v in batch.items()}
+             batch = shard(batch)
+
+             yield batch
+
+     state = flax.jax_utils.replicate(state)
+     # metrics_list = list_metrics()
+
+     actual_task = "mnli"
+     metric = load_metric('glue', "mnli")
+     actual_taskmetric = load_metric('glue', actual_task)
+
+     workdir = '../results_tensorboard'
+     summary_writer = tensorboard.SummaryWriter(workdir)
+     # summary_writer.hparams(dict(GPT2Config()))
+
+     logger.info("***** Running training *****")
+     logger.info(f" Num examples = {len_train_dataset}")
+     logger.info(f" Num Epochs = {num_train_epochs}")
+     logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
+     logger.info(f" Total train batch size = {total_batch_size}")
+     logger.info(f" Total optimization steps = {num_train_steps}")
+
+     for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc="Epoch ...", position=0, leave=True)):
+         rng, input_rng = jax.random.split(rng)
+         train_acc_metrics = []
+         train_loss_metrics = []
+         eval_acc_metrics = []
+         eval_loss_metrics = []
+         # train
+         with tqdm(total=len_train_dataset // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
+             for idx, batch in enumerate(glue_train_data_loader(input_rng, train_dataset, total_batch_size)):
+                 state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
+                 train_acc_metrics.append(jax.device_get(train_metric['accuracy']).mean().item())
+                 train_loss_metrics.append(flax.jax_utils.unreplicate(train_metric)['loss'].item())
+                 if idx % 5 == 0:
+                     summary_writer.scalar('train_loss', flax.jax_utils.unreplicate(train_metric)['loss'].item(), idx)
+                     summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(), idx)
+                 if idx % 20 == 0:
+                     logger.info(f"train_step_loss{idx}: {flax.jax_utils.unreplicate(train_metric)['loss'].item()} train_step_acc{idx}: {jax.device_get(train_metric['accuracy']).mean().item()}")
+
+                 progress_bar_train.update(1)
+
+         # evaluate
+         with tqdm(total=len_validation_dataset // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
+             for idx, batch in enumerate(glue_eval_data_loader(validation_dataset, total_batch_size)):
+                 labels, predictions, eval_metric = parallel_eval_step(state, batch)
+                 eval_acc_metrics.append(jax.device_get(eval_metric['accuracy']).mean().item())
+                 eval_loss_metrics.append(flax.jax_utils.unreplicate(eval_metric)['loss'].item())
+                 progress_bar_eval.update(1)
+                 if idx % 5 == 0:
+                     logger.info(f"eval_step_loss {idx} : {flax.jax_utils.unreplicate(eval_metric)['loss'].item()} eval_step_acc {idx} : {jax.device_get(eval_metric['accuracy']).mean().item()}")
+                     summary_writer.scalar('eval_loss : ', flax.jax_utils.unreplicate(eval_metric)['loss'].item(), idx)
+                     summary_writer.scalar('eval_accuracy : ', jax.device_get(eval_metric['accuracy']).mean().item(), idx)
+
+         if jax.process_index() == 0:
+             # Pull the parameters from the first device and push a checkpoint to the Hub.
+             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+
+             model.save_pretrained(
+                 '../',
+                 params=params,
+                 push_to_hub=True,
+                 commit_message=f"Saving weights of epoch {epoch} at step {idx}",
+             )
+
+         logger.info(f"---------------------Epoch {epoch} done-----------------")
+         logger.info(f"Train loss: {jax.device_get(jnp.array(train_loss_metrics)).mean().item()} Train accuracy: {jax.device_get(jnp.array(train_acc_metrics)).mean().item()}")
+         logger.info(f"Eval loss: {jax.device_get(jnp.array(eval_loss_metrics)).mean().item()} Eval accuracy: {jax.device_get(jnp.array(eval_acc_metrics)).mean().item()}")
+         summary_writer.flush()
+
+ if __name__ == "__main__":
+     main()