PROMPT_TEMPLATE.llama2_chat: performance drop #658

Open
dongjiancheng77 opened this issue May 7, 2024 · 7 comments

@dongjiancheng77

Using llama2_7b_qlora_alpaca_enzh_e3.py as the template, I fine-tuned on gsm8k with QLoRA. After changing PROMPT_TEMPLATE.llama2_chat to PROMPT_TEMPLATE.llama3_chat, acc dropped from 62 to 28. What might be causing this?

@HIT-cwh
Collaborator

HIT-cwh commented May 8, 2024

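Could you please provide the following two pieces of information:

  1. Are you fine-tuning llama2 base or llama2 chat?
  2. After changing the chat template for training, did you change the chat template accordingly at evaluation time?

In principle, fine-tuning a base model with QLoRA is not recommended. QLoRA freezes the embedding layer, and the base model never saw the chat-template tokens during pretraining (for example, <|start_header_id|> in the llama3 chat template). As a result, the model still does not recognize the chat-template tokens after QLoRA fine-tuning.

We recommend either training the base model's chat ability with full-parameter fine-tuning, or fine-tuning a chat model with LoRA/QLoRA.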

@HIT-cwh HIT-cwh self-assigned this May 8, 2024
@HIT-cwh
Collaborator

HIT-cwh commented May 8, 2024

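This happens because chat templates contain special tokens, such as [INST] in llama2 and <|start_header_id|> in llama3.
The llama2 vocabulary contains [INST], and likewise the llama3 vocabulary contains the string <|start_header_id|>. During tokenization, each of these special tokens is therefore mapped to one specific token_id. If you train llama2 with the llama3 chat template, the llama2 tokenizer does not recognize the special strings in that template, which causes the performance drop.

So it is better to use the llama3 chat template when training llama3 base or chat.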
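The point about special tokens can be checked directly: ask the tokenizer to split each template marker and see whether it comes back as one piece. The sketch below is a minimal illustration, not xtuner code. `split_markers` and `make_toy_tokenizer` are hypothetical helpers, and the toy vocabulary is made up, so it says nothing about the real Llama vocabularies.

```python
from typing import Callable, Dict, List

# Marker strings mentioned in this thread; the selection is illustrative.
TEMPLATE_MARKERS = [
    "[INST]", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"
]


def split_markers(tokenize: Callable[[str], List[str]],
                  markers: List[str]) -> Dict[str, List[str]]:
    """Return the markers that the tokenizer breaks into several pieces."""
    broken = {}
    for marker in markers:
        pieces = tokenize(marker)
        if len(pieces) != 1:
            broken[marker] = pieces
    return broken


def make_toy_tokenizer(special_tokens: List[str]) -> Callable[[str], List[str]]:
    """Hypothetical stand-in for a real tokenizer: keeps the given special
    tokens whole and splits everything else into single characters."""

    def tokenize(text: str) -> List[str]:
        pieces, i = [], 0
        while i < len(text):
            match = next(
                (s for s in special_tokens if text.startswith(s, i)), None)
            if match:
                pieces.append(match)
                i += len(match)
            else:
                pieces.append(text[i])
                i += 1
        return pieces

    return tokenize


if __name__ == "__main__":
    # Toy vocabulary that only knows the llama3-style markers (assumption).
    toy = make_toy_tokenizer(
        ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"])
    for marker, pieces in split_markers(toy, TEMPLATE_MARKERS).items():
        print(f"{marker!r} is split into {len(pieces)} pieces")
```

With a real Hugging Face tokenizer, the same check would pass `tokenizer.tokenize` as the `tokenize` argument; any marker that appears in the result is one the model sees as ordinary text fragments rather than as a single token.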

@dongjiancheng77
Author

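  1. I fine-tuned llama3 base, so is it possible that PROMPT_TEMPLATE.llama2_chat performs normally only because it contains no special tokens?
  2. At evaluation time I used the standard gsm8k chat template in both cases, without making the corresponding change.

Sorry, the model I am using is llama3 base.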
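To make the second point concrete, the sketch below builds the prompt a model sees when trained with the llama3-style template and compares it with a prompt that carries no chat markers. The template strings are copied from the sample output in the training log posted in this thread, so treating them as the expansion of PROMPT_TEMPLATE.llama3_chat is an assumption. `build_plain_prompt` is purely hypothetical, since the thread does not show the actual gsm8k evaluation prompt.

```python
# Template pieces inferred from the log's sample output (assumption, not
# copied from xtuner source).
SYSTEM_PART = "<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
USER_PART = ("<|start_header_id|>user<|end_header_id|>\n\n{input}<|eot_id|>"
             "<|start_header_id|>assistant<|end_header_id|>\n\n")


def build_chat_prompt(question: str, system: str) -> str:
    """Prompt in the shape used during training with the llama3-style template."""
    return SYSTEM_PART.format(system=system) + USER_PART.format(input=question)


def build_plain_prompt(question: str) -> str:
    """Hypothetical evaluation-side prompt with no chat markers at all."""
    return f"Question: {question}\nAnswer:"


if __name__ == "__main__":
    question = "How much more money does Betty need to buy the wallet?"
    system = "Below is an instruction that describes a task."
    print(build_chat_prompt(question, system))
    print("-" * 40)
    print(build_plain_prompt(question))
```

If training uses the first shape and evaluation uses the second, the model is scored on inputs it was never fine-tuned on, which is one possible contributor to the gap discussed here.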

@dongjiancheng77
Author

# Copyright (c) OpenMMLab. All rights reserved.

import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import template_map_fn_factory
from mmengine.config import read_base
with read_base():
    from .map_fn import custom_map_fn
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = '/home/nfs02/model/llama-3-8b'
use_varlen_attn = False

# Data
alpaca_en_path = '/home/nfs02/dongjc/MoDS/diverse-data-selection/seed-instructions8.json'
prompt_template = PROMPT_TEMPLATE.llama3_chat
max_length = 2048
pack_to_max_length = True

# parallel
sequence_parallel_size = 1

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 16
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 1500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(
        type=load_dataset, path='json',
        data_files=dict(train=alpaca_en_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=custom_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=False,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

sampler = SequenceParallelSampler \
    if sequence_parallel_size > 1 else DefaultSampler
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per save_steps.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable deterministic
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)

@HIT-cwh
Collaborator

HIT-cwh commented May 8, 2024

Could you share your training log? I'm a bit worried that this is caused by QLoRA being unable to learn the chat template.
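When reading such a log, the per-iteration `loss` and `grad_norm` fields are the quickest things to scan. The following is a minimal sketch that extracts them from mmengine-style `Iter(train)` lines of the kind posted in this thread. The regular expression is written against those lines only, and `scan_log` and `nonfinite_grad_iters` are hypothetical helper names, not part of mmengine or xtuner.

```python
import math
import re
from typing import List, Optional, Tuple

# Matches e.g. "Iter(train) [ 20/2046] ... loss: 0.9169 grad_norm: nan".
# The grad_norm field is optional because the first logged line lacks it.
ITER_LINE = re.compile(
    r"Iter\(train\)\s*\[\s*(\d+)/(\d+)\].*?loss:\s*(\S+)"
    r"(?:\s+grad_norm:\s*(\S+))?")

Record = Tuple[int, float, Optional[float]]


def scan_log(lines: List[str]) -> List[Record]:
    """Collect (iteration, loss, grad_norm) from training log lines."""
    records = []
    for line in lines:
        m = ITER_LINE.search(line)
        if m:
            grad = float(m.group(4)) if m.group(4) else None
            records.append((int(m.group(1)), float(m.group(3)), grad))
    return records


def nonfinite_grad_iters(records: List[Record]) -> List[int]:
    """Iterations whose logged grad_norm is nan or inf."""
    return [it for it, _, grad in records
            if grad is not None and not math.isfinite(grad)]


if __name__ == "__main__":
    sample = [
        "2024/05/07 11:17:20 - mmengine - INFO - Iter(train) [ 10/2046] "
        "lr: 3.0002e-05 eta: 1:41:54 time: 3.0030 data_time: 0.0241 "
        "memory: 13995 loss: 0.9421",
        "2024/05/07 11:17:51 - mmengine - INFO - Iter(train) [ 20/2046] "
        "lr: 6.3335e-05 eta: 1:42:16 time: 3.0543 data_time: 0.0190 "
        "memory: 13995 loss: 0.9169 grad_norm: nan",
    ]
    records = scan_log(sample)
    print("iterations with non-finite grad_norm:",
          nonfinite_grad_iters(records))
```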

@dongjiancheng77
Author

Sure.
2024/05/07 11:11:41 - mmengine - INFO -

System environment:
sys.platform: linux
Python: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2139977810
GPU 0: NVIDIA GeForce RTX 3090
CUDA_HOME: /home/nfs03/cuda_tools/cuda-11.8
NVCC: Cuda compilation tools, release 11.8, V11.8.89
GCC: gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
PyTorch: 2.1.2+cu121
PyTorch compiling details: PyTorch built with:

  • GCC 9.3

  • C++ Version: 201703

  • Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications

  • Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)

  • OpenMP 201511 (a.k.a. OpenMP 4.5)

  • LAPACK is enabled (usually provided by MKL)

  • NNPACK is enabled

  • CPU capability usage: AVX512

  • CUDA Runtime 12.1

  • NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90

  • CuDNN 8.7 (built against CUDA 11.8)

    • Built with CuDNN 8.9.2
  • Magma 2.6.1

  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.2, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

    TorchVision: 0.16.2+cu121
    OpenCV: 4.9.0
    MMEngine: 0.10.3

Runtime environment:
cudnn_benchmark: False
mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}
dist_cfg: {'backend': 'nccl'}
seed: 2139977810
deterministic: False
Distributed launcher: none
Distributed training: False
GPU number: 1

2024/05/07 11:11:41 - mmengine - INFO - Config:
SYSTEM = 'xtuner.utils.SYSTEM_TEMPLATE.alpaca'
accumulative_counts = 16
alpaca_en = dict(
dataset=dict(
data_files=dict(
train=
'/home/nfs02/dongjc/MoDS/diverse-data-selection/seed-instructions8.json'
),
path='json',
type='datasets.load_dataset'),
dataset_map_fn='<function custom_map_fn at 0x7f4380f770a0>',
max_length=2048,
pack_to_max_length=True,
remove_unused_columns=True,
shuffle_before_pack=False,
template_map_fn=dict(
template='xtuner.utils.PROMPT_TEMPLATE.llama3_chat',
type='xtuner.dataset.map_fns.template_map_fn_factory'),
tokenizer=dict(
padding_side='right',
pretrained_model_name_or_path='/home/nfs02/model/llama-3-8b',
trust_remote_code=True,
type='transformers.AutoTokenizer.from_pretrained'),
type='xtuner.dataset.process_hf_dataset',
use_varlen_attn=False)
alpaca_en_path = '/home/nfs02/dongjc/MoDS/diverse-data-selection/seed-instructions8.json'
batch_size = 1
betas = (
0.9,
0.999,
)
custom_hooks = [
dict(
tokenizer=dict(
padding_side='right',
pretrained_model_name_or_path='/home/nfs02/model/llama-3-8b',
trust_remote_code=True,
type='transformers.AutoTokenizer.from_pretrained'),
type='xtuner.engine.hooks.DatasetInfoHook'),
dict(
evaluation_inputs=[
'请给我介绍五个上海的景点',
'Please tell me five scenic spots in Shanghai',
],
every_n_iters=500,
prompt_template='xtuner.utils.PROMPT_TEMPLATE.llama3_chat',
system='xtuner.utils.SYSTEM_TEMPLATE.alpaca',
tokenizer=dict(
padding_side='right',
pretrained_model_name_or_path='/home/nfs02/model/llama-3-8b',
trust_remote_code=True,
type='transformers.AutoTokenizer.from_pretrained'),
type='xtuner.engine.hooks.EvaluateChatHook'),
]
custom_map_fn = '<function custom_map_fn at 0x7f4380f770a0>'
dataloader_num_workers = 0
default_hooks = dict(
checkpoint=dict(
by_epoch=False,
interval=1500,
max_keep_ckpts=2,
type='mmengine.hooks.CheckpointHook'),
logger=dict(
interval=10,
log_metric_by_epoch=False,
type='mmengine.hooks.LoggerHook'),
param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'),
sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'),
timer=dict(type='mmengine.hooks.IterTimerHook'))
env_cfg = dict(
cudnn_benchmark=False,
dist_cfg=dict(backend='nccl'),
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
evaluation_freq = 500
evaluation_inputs = [
'请给我介绍五个上海的景点',
'Please tell me five scenic spots in Shanghai',
]
launcher = 'none'
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=False)
lr = 0.0002
max_epochs = 3
max_length = 2048
max_norm = 1
model = dict(
llm=dict(
pretrained_model_name_or_path='/home/nfs02/model/llama-3-8b',
quantization_config=dict(
bnb_4bit_compute_dtype='torch.float16',
bnb_4bit_quant_type='nf4',
bnb_4bit_use_double_quant=True,
llm_int8_has_fp16_weight=False,
llm_int8_threshold=6.0,
load_in_4bit=True,
load_in_8bit=False,
type='transformers.BitsAndBytesConfig'),
torch_dtype='torch.float16',
trust_remote_code=True,
type='transformers.AutoModelForCausalLM.from_pretrained'),
lora=dict(
bias='none',
lora_alpha=16,
lora_dropout=0.1,
r=64,
task_type='CAUSAL_LM',
type='peft.LoraConfig'),
type='xtuner.model.SupervisedFinetune',
use_varlen_attn=False)
optim_type = 'torch.optim.AdamW'
optim_wrapper = dict(
accumulative_counts=16,
clip_grad=dict(error_if_nonfinite=False, max_norm=1),
dtype='float16',
loss_scale='dynamic',
optimizer=dict(
betas=(
0.9,
0.999,
),
lr=0.0002,
type='torch.optim.AdamW',
weight_decay=0),
type='mmengine.optim.AmpOptimWrapper')
pack_to_max_length = True
param_scheduler = [
dict(
begin=0,
by_epoch=True,
convert_to_iter_based=True,
end=0.09,
start_factor=1e-05,
type='mmengine.optim.LinearLR'),
dict(
begin=0.09,
by_epoch=True,
convert_to_iter_based=True,
end=3,
eta_min=0.0,
type='mmengine.optim.CosineAnnealingLR'),
]
pretrained_model_name_or_path = '/home/nfs02/model/llama-3-8b'
prompt_template = 'xtuner.utils.PROMPT_TEMPLATE.llama3_chat'
randomness = dict(deterministic=False, seed=None)
resume = False
sampler = 'mmengine.dataset.DefaultSampler'
save_steps = 1500
save_total_limit = 2
sequence_parallel_size = 1
tokenizer = dict(
padding_side='right',
pretrained_model_name_or_path='/home/nfs02/model/llama-3-8b',
trust_remote_code=True,
type='transformers.AutoTokenizer.from_pretrained')
train_cfg = dict(max_epochs=3, type='xtuner.engine.runner.TrainLoop')
train_dataloader = dict(
batch_size=1,
collate_fn=dict(
type='xtuner.dataset.collate_fns.default_collate_fn',
use_varlen_attn=False),
dataset=dict(
dataset=dict(
data_files=dict(
train=
'/home/nfs02/dongjc/MoDS/diverse-data-selection/seed-instructions8.json'
),
path='json',
type='datasets.load_dataset'),
dataset_map_fn='<function custom_map_fn at 0x7f4380f770a0>',
max_length=2048,
pack_to_max_length=True,
remove_unused_columns=True,
shuffle_before_pack=False,
template_map_fn=dict(
template='xtuner.utils.PROMPT_TEMPLATE.llama3_chat',
type='xtuner.dataset.map_fns.template_map_fn_factory'),
tokenizer=dict(
padding_side='right',
pretrained_model_name_or_path='/home/nfs02/model/llama-3-8b',
trust_remote_code=True,
type='transformers.AutoTokenizer.from_pretrained'),
type='xtuner.dataset.process_hf_dataset',
use_varlen_attn=False),
num_workers=0,
sampler=dict(shuffle=True, type='mmengine.dataset.DefaultSampler'))
use_varlen_attn = False
visualizer = None
warmup_ratio = 0.03
weight_decay = 0
work_dir = './work_dirs/llama2_7b_qlora_alpaca_e332'

2024/05/07 11:11:42 - mmengine - WARNING - Failed to search registry with scope "mmengine" in the "builder" registry tree. As a workaround, the current "builder" registry in "xtuner" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmengine" is a correct scope, or whether the registry is initialized.
2024/05/07 11:15:56 - mmengine - INFO - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.
2024/05/07 11:15:59 - mmengine - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH ) RuntimeInfoHook
(BELOW_NORMAL) LoggerHook

before_train:
(VERY_HIGH ) RuntimeInfoHook
(NORMAL ) IterTimerHook
(NORMAL ) DatasetInfoHook
(LOW ) EvaluateChatHook
(VERY_LOW ) CheckpointHook

before_train_epoch:
(VERY_HIGH ) RuntimeInfoHook
(NORMAL ) IterTimerHook
(NORMAL ) DistSamplerSeedHook

before_train_iter:
(VERY_HIGH ) RuntimeInfoHook
(NORMAL ) IterTimerHook

after_train_iter:
(VERY_HIGH ) RuntimeInfoHook
(NORMAL ) IterTimerHook
(BELOW_NORMAL) LoggerHook
(LOW ) ParamSchedulerHook
(LOW ) EvaluateChatHook
(VERY_LOW ) CheckpointHook

after_train_epoch:
(NORMAL ) IterTimerHook
(LOW ) ParamSchedulerHook
(VERY_LOW ) CheckpointHook

before_val:
(VERY_HIGH ) RuntimeInfoHook
(NORMAL ) DatasetInfoHook

before_val_epoch:
(NORMAL ) IterTimerHook

before_val_iter:
(NORMAL ) IterTimerHook

after_val_iter:
(NORMAL ) IterTimerHook
(BELOW_NORMAL) LoggerHook

after_val_epoch:
(VERY_HIGH ) RuntimeInfoHook
(NORMAL ) IterTimerHook
(BELOW_NORMAL) LoggerHook
(LOW ) ParamSchedulerHook
(VERY_LOW ) CheckpointHook

after_val:
(VERY_HIGH ) RuntimeInfoHook
(LOW ) EvaluateChatHook

after_train:
(VERY_HIGH ) RuntimeInfoHook
(LOW ) EvaluateChatHook
(VERY_LOW ) CheckpointHook

before_test:
(VERY_HIGH ) RuntimeInfoHook
(NORMAL ) DatasetInfoHook

before_test_epoch:
(NORMAL ) IterTimerHook

before_test_iter:
(NORMAL ) IterTimerHook

after_test_iter:
(NORMAL ) IterTimerHook
(BELOW_NORMAL) LoggerHook

after_test_epoch:
(VERY_HIGH ) RuntimeInfoHook
(NORMAL ) IterTimerHook
(BELOW_NORMAL) LoggerHook

after_test:
(VERY_HIGH ) RuntimeInfoHook

after_run:
(BELOW_NORMAL) LoggerHook

2024/05/07 11:16:31 - mmengine - WARNING - Dataset Dataset has no metainfo. dataset_meta in visualizer will be None.
2024/05/07 11:16:32 - mmengine - INFO - Num train samples 682
2024/05/07 11:16:32 - mmengine - INFO - train example:
2024/05/07 11:16:32 - mmengine - INFO - <|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50.
Betty's grandparents gave her 15 * 2 = $<<15*2=30>>30.
This means, Betty needs 100 - 50 - 30 - 15 = $<<100-50-30-15=5>>5 more.

5<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Gabby is saving money to buy a new makeup set. The makeup set costs $65 and she already has $35. Gabby’s mom gives her an additional $20. How much money does Gabby need to buy the set?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Deduct the amount saved from the total cost of the set. $65 - $35 = $<<65-35=30>>30
Deduct the amount of money Gabby received from her mom. $30 - $20 = $<<30-20=10>>10

10<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Vanessa wants to buy a dress she saw at the mall, which costs $80, and she already has $20 in savings. Her parents give her $30 every week, but she also spends $10 each weekend at the arcades. How many weeks will she have to wait until she can gather enough money to buy the dress?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Vanessa needs $80 – $20 = $<<80-20=60>>60 to buy the dress.
She manages to gather $30 - $10 = $<<30-10=20>>20 each week
The number of weeks she has to wait is 60 ÷ 20 = <<60/20=3>>3 weeks.

3<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Barbara wants to save up for a new wristwatch that costs $100. Her parents give her an allowance of $5 a week and she can either save it all up for the watch or spend it as she wishes. 10 weeks pass and due to spending some of her money on ice cream, Barbara currently only has $20. How many more weeks does she need to save for a watch if she stops spending on other things right now?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Since barbara needs $100, and she currently has $20, she is remaining with $100-$20 = $<<100-20=80>>80 to save
With an allowance of $5 each week, she'll need to save for 80/5 = 16 weeks

16<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Zoe wants to go on the field trip to Washington DC with her middle school this spring and the cost is $485. Her grandma gave her $250 toward her fees and she must earn the rest by selling candy bars. She makes $1.25 for every candy bar she sells. How many candy bars does Zoe need to sell to earn the money for the trip?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Since Zoe’s grandma gave her some money toward her trip, she needs to come up with $485 – $250 = $<<485-250=235>>235.
This means she must sell $235 / $1.25/candy bar = <<235/1.25=188>>188 candy bars.

188<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Whitney’s mom gave her two $20 bills to spend at the school book fair. Whitney has decided to buy 2 posters, 3 notebooks, and 2 bookmarks. Each poster costs $5, each notebook costs $4, and each bookmark costs $2. How much money, in dollars, will Whitney have left over after the purchase?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Whitney is buying 2 posters for $5 each, so the posters will cost 2*$5= $<<2*5=10>>10 total cost for posters.
Whitney is buying 3 notebooks for $4 each, so the notebooks will cost 3*$4= $<<3*4=12>>12 total cost for notebooks.
Whitney is buying 2 notebooks for $2, so the bookmarks 2*$2= $<<2*2=4>>4 total cost for bookmarks.
Since Whitney is paying $10 for posters, $12 for notebooks, and $4 for bookmarks, her total purchase will cost $10+$12+$4= $<<10+12+4=26>>26 total purchase cost.
Whitney’s mom gave her 2 $20 bills, so she will be paying with 2*$20=$<<2*20=40>>40 total payment.
Since Whitney is paying with $40 and her purchase cost will be $26, she will have $40-$26= $<<40-26=14>>14 left over after the purchase.

14<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Rose is an aspiring artist. She wants a paintbrush that costs $2.40, a set of paints that costs $9.20, and an easel that costs $6.50 so she can do some paintings. Rose already has $7.10. How much more money does Rose need?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The total cost of the paintbrush, the paints, and the easel was $2.40 + $9.20 + $6.50 = $<<2.4+9.2+6.5=18.10>>18.10.
Rose needs $18.10 - $7.10 = $<<18.10-7.10=11>>11 more.

11<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Maggie has an after-school job that pays her $5.00 for every magazine subscription she can sell. She sells 4 to her parents, 1 to her grandfather, 2 to the next-door neighbor and twice that amount to another neighbor. How much money did Maggie earn?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Maggie sells 2 subscriptions to a neighbor and twice that amount to another neighbor so she sells 2*2 = <<2*2=4>>4 subscriptions to the other neighbor
In total, Maggie sells 4+1+2+4 = <<4+1+2+4=11>>11 subscriptions
She earns $5.00 per subscription she sells so she earns 5*11 = $<<5*11=55.00>>55.00

55<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Tina saved $27 in June, $14 in July, and $21 in August. Then Tina spent $5 on books and $17 on new shoes. How much money does Tina have left?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The total amount of money saved is $27 + $14 + $21 = $<<27+14+21=62>>62.
The total amount spent on books and new shoes is $5 + $17 = $<<5+17=22>>22.
Tina has $62 − $22 = $40 left.

40<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Kate saved $27 in March. She saved $13 in April and $28 in May. Then Kate spent $49 on a keyboard and $5 on a mouse. How much money does Kate have left?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The total amount of money saved is $27 + $13 + $28 = $<<27+13+28=68>>68.
The total cost of the two products is $49 + $5 = $<<49+5=54>>54.
Kate has $68 − $54 = $14 left.

14<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Hannah sold 40 pieces of cookies for $0.8 each and 30 cupcakes for $2 each. She used the money to buy 2 sets of measuring spoons for $6.5 each. How much money does she have left?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Hannah's earnings from the cookies is 40 x $0.8 = $<<40*0.8=32>>32.
Her earnings from the cupcakes is 30 x $2 = $<<30*2=60>>60.
Her total earnings for the cupcakes and cookies is $32 + $60 = $<<32+60=92>>92.
The cost of 2 sets of measuring spoons is 2 x $6.5 = $<<2*6.5=13>>13.
So, Hannah has
2024/05/07 11:16:32 - mmengine - INFO - before_train in EvaluateChatHook.
2024/05/07 11:16:44 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

请给我介绍五个上海的景点<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Please give me a list of five Shanghai sights.

Please give me a list of five Shanghai sights.

Please give me a list of five Shanghai sights.

Please give me a list of five Shanghai sights.

Please give me a list of five Shanghai sights.

2024/05/07 11:16:50 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Please tell me five scenic spots in Shanghai<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Please tell me five scenic spots in Shanghai.

Please tell me five scenic spots in Shanghai.

Please tell me five scenic spots in Shanghai.

Please tell me five scenic spots in Shanghai.

Please tell me five scenic spots in Shanghai.

Please tell me five scenic

2024/05/07 11:16:50 - mmengine - WARNING - "FileClient" will be deprecated in future. Please use io functions in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
2024/05/07 11:16:50 - mmengine - WARNING - "HardDiskBackend" is the alias of "LocalBackend" and the former will be deprecated in future.
2024/05/07 11:16:50 - mmengine - INFO - Checkpoints will be saved to /home/nfs03/dongjc/32/work_dirs/llama2_7b_qlora_alpaca_e332.
2024/05/07 11:17:20 - mmengine - INFO - Iter(train) [ 10/2046] lr: 3.0002e-05 eta: 1:41:54 time: 3.0030 data_time: 0.0241 memory: 13995 loss: 0.9421
2024/05/07 11:17:51 - mmengine - INFO - Iter(train) [ 20/2046] lr: 6.3335e-05 eta: 1:42:16 time: 3.0543 data_time: 0.0190 memory: 13995 loss: 0.9169 grad_norm: nan
2024/05/07 11:18:20 - mmengine - INFO - Iter(train) [ 30/2046] lr: 9.6668e-05 eta: 1:40:34 time: 2.9234 data_time: 0.0188 memory: 13993 loss: 0.8919 grad_norm: nan
2024/05/07 11:18:51 - mmengine - INFO - Iter(train) [ 40/2046] lr: 1.3000e-04 eta: 1:40:59 time: 3.1012 data_time: 0.0232 memory: 15266 loss: 0.8997 grad_norm: nan
2024/05/07 11:19:21 - mmengine - INFO - Iter(train) [ 50/2046] lr: 1.6333e-04 eta: 1:39:56 time: 2.9391 data_time: 0.0196 memory: 15266 loss: 0.9710 grad_norm: nan
2024/05/07 11:19:50 - mmengine - INFO - Iter(train) [ 60/2046] lr: 1.9667e-04 eta: 1:39:06 time: 2.9435 data_time: 0.0191 memory: 15266 loss: 0.8769 grad_norm: nan
2024/05/07 11:20:20 - mmengine - INFO - Iter(train) [ 70/2046] lr: 1.9999e-04 eta: 1:38:23 time: 2.9475 data_time: 0.0189 memory: 15266 loss: 0.8763 grad_norm: nan
2024/05/07 11:20:49 - mmengine - INFO - Iter(train) [ 80/2046] lr: 1.9996e-04 eta: 1:37:44 time: 2.9522 data_time: 0.0194 memory: 15266 loss: 0.8272 grad_norm: nan
2024/05/07 11:21:19 - mmengine - INFO - Iter(train) [ 90/2046] lr: 1.9990e-04 eta: 1:37:06 time: 2.9449 data_time: 0.0179 memory: 15266 loss: 0.8514 grad_norm: nan
2024/05/07 11:21:48 - mmengine - INFO - Iter(train) [ 100/2046] lr: 1.9982e-04 eta: 1:36:30 time: 2.9483 data_time: 0.0171 memory: 15266 loss: 0.8082 grad_norm: nan
2024/05/07 11:22:17 - mmengine - INFO - Iter(train) [ 110/2046] lr: 1.9971e-04 eta: 1:35:56 time: 2.9480 data_time: 0.0193 memory: 15266 loss: 0.7965 grad_norm: nan
2024/05/07 11:22:47 - mmengine - INFO - Iter(train) [ 120/2046] lr: 1.9958e-04 eta: 1:35:22 time: 2.9517 data_time: 0.0205 memory: 15266 loss: 0.7757 grad_norm: nan
2024/05/07 11:23:17 - mmengine - INFO - Iter(train) [ 130/2046] lr: 1.9942e-04 eta: 1:34:50 time: 2.9513 data_time: 0.0196 memory: 15266 loss: 0.7594 grad_norm: nan
2024/05/07 11:23:47 - mmengine - INFO - Iter(train) [ 140/2046] lr: 1.9924e-04 eta: 1:34:29 time: 3.0344 data_time: 0.0245 memory: 15266 loss: 0.7240 grad_norm: nan
2024/05/07 11:24:16 - mmengine - INFO - Iter(train) [ 150/2046] lr: 1.9903e-04 eta: 1:33:56 time: 2.9508 data_time: 0.0215 memory: 15266 loss: 0.7447 grad_norm: nan
2024/05/07 11:24:46 - mmengine - INFO - Iter(train) [ 160/2046] lr: 1.9880e-04 eta: 1:33:24 time: 2.9528 data_time: 0.0210 memory: 15266 loss: 0.7693 grad_norm: nan
2024/05/07 11:25:15 - mmengine - INFO - Iter(train) [ 170/2046] lr: 1.9854e-04 eta: 1:32:52 time: 2.9553 data_time: 0.0234 memory: 15266 loss: 0.7103 grad_norm: nan
2024/05/07 11:25:45 - mmengine - INFO - Iter(train) [ 180/2046] lr: 1.9826e-04 eta: 1:32:20 time: 2.9480 data_time: 0.0195 memory: 15266 loss: 0.6376 grad_norm: nan
2024/05/07 11:26:14 - mmengine - INFO - Iter(train) [ 190/2046] lr: 1.9796e-04 eta: 1:31:49 time: 2.9544 data_time: 0.0220 memory: 15266 loss: 0.6903 grad_norm: nan
2024/05/07 11:26:44 - mmengine - INFO - Iter(train) [ 200/2046] lr: 1.9762e-04 eta: 1:31:18 time: 2.9497 data_time: 0.0222 memory: 15266 loss: 0.6554 grad_norm: nan
2024/05/07 11:27:13 - mmengine - INFO - Iter(train) [ 210/2046] lr: 1.9727e-04 eta: 1:30:47 time: 2.9490 data_time: 0.0222 memory: 15266 loss: 0.7249 grad_norm: nan
2024/05/07 11:27:43 - mmengine - INFO - Iter(train) [ 220/2046] lr: 1.9689e-04 eta: 1:30:15 time: 2.9469 data_time: 0.0230 memory: 15266 loss: 0.7279 grad_norm: nan
2024/05/07 11:28:12 - mmengine - INFO - Iter(train) [ 230/2046] lr: 1.9649e-04 eta: 1:29:44 time: 2.9511 data_time: 0.0223 memory: 15266 loss: 0.6784 grad_norm: nan
2024/05/07 11:28:42 - mmengine - INFO - Iter(train) [ 240/2046] lr: 1.9606e-04 eta: 1:29:14 time: 2.9495 data_time: 0.0229 memory: 15266 loss: 0.7469 grad_norm: nan
2024/05/07 11:29:11 - mmengine - INFO - Iter(train) [ 250/2046] lr: 1.9561e-04 eta: 1:28:43 time: 2.9452 data_time: 0.0196 memory: 15266 loss: 0.7041 grad_norm: nan
2024/05/07 11:29:41 - mmengine - INFO - Iter(train) [ 260/2046] lr: 1.9513e-04 eta: 1:28:12 time: 2.9488 data_time: 0.0206 memory: 15266 loss: 0.7529 grad_norm: nan
2024/05/07 11:30:10 - mmengine - INFO - Iter(train) [ 270/2046] lr: 1.9463e-04 eta: 1:27:41 time: 2.9499 data_time: 0.0258 memory: 15266 loss: 0.7342 grad_norm: nan
2024/05/07 11:30:40 - mmengine - INFO - Iter(train) [ 280/2046] lr: 1.9411e-04 eta: 1:27:11 time: 2.9531 data_time: 0.0211 memory: 15266 loss: 0.6903 grad_norm: nan
2024/05/07 11:31:09 - mmengine - INFO - Iter(train) [ 290/2046] lr: 1.9356e-04 eta: 1:26:40 time: 2.9449 data_time: 0.0202 memory: 15266 loss: 0.6879 grad_norm: nan
2024/05/07 11:31:39 - mmengine - INFO - Iter(train) [ 300/2046] lr: 1.9299e-04 eta: 1:26:10 time: 2.9444 data_time: 0.0201 memory: 15266 loss: 0.7110 grad_norm: nan
2024/05/07 11:32:08 - mmengine - INFO - Iter(train) [ 310/2046] lr: 1.9240e-04 eta: 1:25:39 time: 2.9443 data_time: 0.0202 memory: 15266 loss: 0.6765 grad_norm: nan
2024/05/07 11:32:38 - mmengine - INFO - Iter(train) [ 320/2046] lr: 1.9178e-04 eta: 1:25:09 time: 2.9485 data_time: 0.0216 memory: 15266 loss: 0.7005 grad_norm: nan
2024/05/07 11:33:07 - mmengine - INFO - Iter(train) [ 330/2046] lr: 1.9114e-04 eta: 1:24:38 time: 2.9436 data_time: 0.0216 memory: 15266 loss: 0.7194 grad_norm: nan
2024/05/07 11:33:37 - mmengine - INFO - Iter(train) [ 340/2046] lr: 1.9048e-04 eta: 1:24:08 time: 2.9426 data_time: 0.0192 memory: 15266 loss: 0.6991 grad_norm: nan
2024/05/07 11:34:06 - mmengine - INFO - Iter(train) [ 350/2046] lr: 1.8979e-04 eta: 1:23:37 time: 2.9376 data_time: 0.0186 memory: 15266 loss: 0.6662 grad_norm: nan
2024/05/07 11:34:35 - mmengine - INFO - Iter(train) [ 360/2046] lr: 1.8908e-04 eta: 1:23:07 time: 2.9394 data_time: 0.0205 memory: 15266 loss: 0.6635 grad_norm: nan
2024/05/07 11:35:05 - mmengine - INFO - Iter(train) [ 370/2046] lr: 1.8835e-04 eta: 1:22:37 time: 2.9406 data_time: 0.0216 memory: 15266 loss: 0.6787 grad_norm: nan
2024/05/07 11:35:37 - mmengine - INFO - Iter(train) [ 380/2046] lr: 1.8760e-04 eta: 1:22:20 time: 3.2504 data_time: 0.0165 memory: 15266 loss: 0.6992 grad_norm: nan
2024/05/07 11:36:07 - mmengine - INFO - Iter(train) [ 390/2046] lr: 1.8683e-04 eta: 1:21:49 time: 2.9403 data_time: 0.0242 memory: 15266 loss: 0.6530 grad_norm: nan
2024/05/07 11:36:36 - mmengine - INFO - Iter(train) [ 400/2046] lr: 1.8603e-04 eta: 1:21:18 time: 2.9417 data_time: 0.0229 memory: 15266 loss: 0.6449 grad_norm: nan
2024/05/07 11:37:05 - mmengine - INFO - Iter(train) [ 410/2046] lr: 1.8521e-04 eta: 1:20:48 time: 2.9321 data_time: 0.0214 memory: 15266 loss: 0.6579 grad_norm: nan
2024/05/07 11:37:35 - mmengine - INFO - Iter(train) [ 420/2046] lr: 1.8437e-04 eta: 1:20:17 time: 2.9347 data_time: 0.0216 memory: 15266 loss: 0.6713 grad_norm: nan
2024/05/07 11:38:04 - mmengine - INFO - Iter(train) [ 430/2046] lr: 1.8351e-04 eta: 1:19:46 time: 2.9361 data_time: 0.0229 memory: 15266 loss: 0.7113 grad_norm: nan
2024/05/07 11:38:34 - mmengine - INFO - Iter(train) [ 440/2046] lr: 1.8263e-04 eta: 1:19:16 time: 2.9440 data_time: 0.0318 memory: 15266 loss: 0.6479 grad_norm: nan
2024/05/07 11:39:03 - mmengine - INFO - Iter(train) [ 450/2046] lr: 1.8173e-04 eta: 1:18:45 time: 2.9310 data_time: 0.0188 memory: 15266 loss: 0.6951 grad_norm: nan
2024/05/07 11:39:32 - mmengine - INFO - Iter(train) [ 460/2046] lr: 1.8081e-04 eta: 1:18:15 time: 2.9306 data_time: 0.0195 memory: 15266 loss: 0.6637 grad_norm: nan
2024/05/07 11:40:03 - mmengine - INFO - Iter(train) [ 470/2046] lr: 1.7987e-04 eta: 1:17:48 time: 3.0641 data_time: 0.1564 memory: 15266 loss: 0.7217 grad_norm: nan
2024/05/07 11:40:32 - mmengine - INFO - Iter(train) [ 480/2046] lr: 1.7890e-04 eta: 1:17:18 time: 2.9304 data_time: 0.0226 memory: 15266 loss: 0.6386 grad_norm: nan
2024/05/07 11:41:01 - mmengine - INFO - Iter(train) [ 490/2046] lr: 1.7792e-04 eta: 1:16:47 time: 2.9231 data_time: 0.0182 memory: 15266 loss: 0.6352 grad_norm: nan
2024/05/07 11:41:31 - mmengine - INFO - Iter(train) [ 500/2046] lr: 1.7692e-04 eta: 1:16:16 time: 2.9280 data_time: 0.0216 memory: 15266 loss: 0.6962 grad_norm: nan
2024/05/07 11:41:31 - mmengine - INFO - after_train_iter in EvaluateChatHook.
2024/05/07 11:43:09 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

请给我介绍五个上海的景点<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Here are five Shanghai attractions.

  1. The Bund
  2. Yu Garden
  3. Shanghai Museum
  4. Shanghai Tower
  5. Oriental Pearl TV Tower

1-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5-5

2024/05/07 11:44:47 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Please tell me five scenic spots in Shanghai<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Five scenic spots in Shanghai are the Bund, the Oriental Pearl Tower, the Yuyuan Garden, the Shanghai Museum, and the Shanghai Science and Technology Museum.

5 scenic spots in Shanghai

[... the line "5 scenic spots in Shanghai" repeats 70 more times (71 in total), each followed by a blank line ...]

2024/05/07 11:45:16 - mmengine - INFO - Iter(train) [ 510/2046] lr: 1.7590e-04 eta: 1:25:37 time: 22.5450 data_time: 19.6772 memory: 15266 loss: 0.6919 grad_norm: nan
2024/05/07 11:45:45 - mmengine - INFO - Iter(train) [ 520/2046] lr: 1.7486e-04 eta: 1:24:50 time: 2.9098 data_time: 0.0234 memory: 15266 loss: 0.6479 grad_norm: nan
2024/05/07 11:46:14 - mmengine - INFO - Iter(train) [ 530/2046] lr: 1.7380e-04 eta: 1:24:05 time: 2.9166 data_time: 0.0211 memory: 15266 loss: 0.6724 grad_norm: nan
2024/05/07 11:46:44 - mmengine - INFO - Iter(train) [ 540/2046] lr: 1.7272e-04 eta: 1:23:20 time: 2.9215 data_time: 0.0221 memory: 15266 loss: 0.6974 grad_norm: nan
2024/05/07 11:47:13 - mmengine - INFO - Iter(train) [ 550/2046] lr: 1.7163e-04 eta: 1:22:36 time: 2.9239 data_time: 0.0201 memory: 15266 loss: 0.7089 grad_norm: nan
2024/05/07 11:47:42 - mmengine - INFO - Iter(train) [ 560/2046] lr: 1.7051e-04 eta: 1:21:53 time: 2.9212 data_time: 0.0198 memory: 15266 loss: 0.6493 grad_norm: nan
2024/05/07 11:48:11 - mmengine - INFO - Iter(train) [ 570/2046] lr: 1.6938e-04 eta: 1:21:10 time: 2.9249 data_time: 0.0240 memory: 15266 loss: 0.6526 grad_norm: nan
2024/05/07 11:48:41 - mmengine - INFO - Iter(train) [ 580/2046] lr: 1.6824e-04 eta: 1:20:27 time: 2.9221 data_time: 0.0201 memory: 15266 loss: 0.6829 grad_norm: nan
2024/05/07 11:49:10 - mmengine - INFO - Iter(train) [ 590/2046] lr: 1.6707e-04 eta: 1:19:45 time: 2.9195 data_time: 0.0185 memory: 15266 loss: 0.6081 grad_norm: nan
2024/05/07 11:49:39 - mmengine - INFO - Iter(train) [ 600/2046] lr: 1.6589e-04 eta: 1:19:04 time: 2.9272 data_time: 0.0224 memory: 15266 loss: 0.6401 grad_norm: nan
2024/05/07 11:50:08 - mmengine - INFO - Iter(train) [ 610/2046] lr: 1.6469e-04 eta: 1:18:22 time: 2.9232 data_time: 0.0197 memory: 15266 loss: 0.6764 grad_norm: nan
2024/05/07 11:50:37 - mmengine - INFO - Iter(train) [ 620/2046] lr: 1.6347e-04 eta: 1:17:42 time: 2.9225 data_time: 0.0214 memory: 15266 loss: 0.6710 grad_norm: nan
2024/05/07 11:51:07 - mmengine - INFO - Iter(train) [ 630/2046] lr: 1.6224e-04 eta: 1:17:01 time: 2.9235 data_time: 0.0226 memory: 15266 loss: 0.6631 grad_norm: nan
2024/05/07 11:51:36 - mmengine - INFO - Iter(train) [ 640/2046] lr: 1.6100e-04 eta: 1:16:21 time: 2.9190 data_time: 0.0192 memory: 15266 loss: 0.6409 grad_norm: nan
2024/05/07 11:52:05 - mmengine - INFO - Iter(train) [ 650/2046] lr: 1.5973e-04 eta: 1:15:41 time: 2.9207 data_time: 0.0218 memory: 15266 loss: 0.6175 grad_norm: nan
2024/05/07 11:52:34 - mmengine - INFO - Iter(train) [ 660/2046] lr: 1.5846e-04 eta: 1:15:02 time: 2.9233 data_time: 0.0220 memory: 15266 loss: 0.6973 grad_norm: nan
2024/05/07 11:53:04 - mmengine - INFO - Iter(train) [ 670/2046] lr: 1.5717e-04 eta: 1:14:22 time: 2.9243 data_time: 0.0237 memory: 15266 loss: 0.6745 grad_norm: nan
2024/05/07 11:53:33 - mmengine - INFO - Iter(train) [ 680/2046] lr: 1.5586e-04 eta: 1:13:43 time: 2.9193 data_time: 0.0189 memory: 15266 loss: 0.6450 grad_norm: nan
2024/05/07 11:53:39 - mmengine - INFO - Exp name: llama2_7b_qlora_alpaca_e332_20240507_111134
2024/05/07 11:53:39 - mmengine - WARNING - Reach the end of the dataloader, it will be restarted and continue to iterate. It is recommended to use mmengine.dataset.InfiniteSampler to enable the dataloader to iterate infinitely.
2024/05/07 11:54:04 - mmengine - INFO - Iter(train) [ 690/2046] lr: 1.5454e-04 eta: 1:13:09 time: 3.1221 data_time: 0.2212 memory: 15266 loss: 0.6800 grad_norm: nan
2024/05/07 11:54:33 - mmengine - INFO - Iter(train) [ 700/2046] lr: 1.5321e-04 eta: 1:12:30 time: 2.9242 data_time: 0.0252 memory: 15266 loss: 0.6869 grad_norm: nan
2024/05/07 11:55:02 - mmengine - INFO - Iter(train) [ 710/2046] lr: 1.5186e-04 eta: 1:11:52 time: 2.9258 data_time: 0.0227 memory: 15266 loss: 0.6760 grad_norm: nan
2024/05/07 11:55:32 - mmengine - INFO - Iter(train) [ 720/2046] lr: 1.5050e-04 eta: 1:11:14 time: 2.9235 data_time: 0.0209 memory: 15266 loss: 0.7334 grad_norm: nan
2024/05/07 11:56:01 - mmengine - INFO - Iter(train) [ 730/2046] lr: 1.4913e-04 eta: 1:10:37 time: 2.9199 data_time: 0.0206 memory: 15266 loss: 0.7073 grad_norm: nan
2024/05/07 11:56:30 - mmengine - INFO - Iter(train) [ 740/2046] lr: 1.4774e-04 eta: 1:09:59 time: 2.9244 data_time: 0.0233 memory: 15266 loss: 0.7682 grad_norm: nan
2024/05/07 11:56:59 - mmengine - INFO - Iter(train) [ 750/2046] lr: 1.4635e-04 eta: 1:09:22 time: 2.9184 data_time: 0.0210 memory: 15266 loss: 0.6898 grad_norm: nan
2024/05/07 11:57:29 - mmengine - INFO - Iter(train) [ 760/2046] lr: 1.4494e-04 eta: 1:08:45 time: 2.9234 data_time: 0.0220 memory: 15266 loss: 0.7508 grad_norm: nan
2024/05/07 11:57:58 - mmengine - INFO - Iter(train) [ 770/2046] lr: 1.4352e-04 eta: 1:08:08 time: 2.9230 data_time: 0.0209 memory: 15266 loss: 0.7212 grad_norm: nan
2024/05/07 11:58:27 - mmengine - INFO - Iter(train) [ 780/2046] lr: 1.4209e-04 eta: 1:07:32 time: 2.9208 data_time: 0.0201 memory: 15266 loss: 0.7635 grad_norm: nan
2024/05/07 11:58:56 - mmengine - INFO - Iter(train) [ 790/2046] lr: 1.4065e-04 eta: 1:06:55 time: 2.9226 data_time: 0.0205 memory: 15266 loss: 0.6989 grad_norm: nan
2024/05/07 11:59:25 - mmengine - INFO - Iter(train) [ 800/2046] lr: 1.3920e-04 eta: 1:06:19 time: 2.9254 data_time: 0.0214 memory: 15266 loss: 0.7378 grad_norm: 0.0007
2024/05/07 11:59:55 - mmengine - INFO - Iter(train) [ 810/2046] lr: 1.3774e-04 eta: 1:05:43 time: 2.9214 data_time: 0.0222 memory: 15266 loss: 0.8181 grad_norm: 0.0007
2024/05/07 12:00:24 - mmengine - INFO - Iter(train) [ 820/2046] lr: 1.3627e-04 eta: 1:05:07 time: 2.9232 data_time: 0.0204 memory: 15266 loss: 0.7450 grad_norm: 0.0007
2024/05/07 12:00:53 - mmengine - INFO - Iter(train) [ 830/2046] lr: 1.3479e-04 eta: 1:04:31 time: 2.9243 data_time: 0.0231 memory: 15266 loss: 0.7353 grad_norm: 0.0007
2024/05/07 12:01:22 - mmengine - INFO - Iter(train) [ 840/2046] lr: 1.3330e-04 eta: 1:03:56 time: 2.9228 data_time: 0.0204 memory: 15266 loss: 0.7605 grad_norm: 0.0006
2024/05/07 12:01:52 - mmengine - INFO - Iter(train) [ 850/2046] lr: 1.3180e-04 eta: 1:03:20 time: 2.9269 data_time: 0.0233 memory: 15266 loss: 0.7828 grad_norm: 0.0007
2024/05/07 12:02:21 - mmengine - INFO - Iter(train) [ 860/2046] lr: 1.3030e-04 eta: 1:02:45 time: 2.9193 data_time: 0.0203 memory: 15266 loss: 0.7780 grad_norm: 0.0007
2024/05/07 12:02:50 - mmengine - INFO - Iter(train) [ 870/2046] lr: 1.2879e-04 eta: 1:02:10 time: 2.9249 data_time: 0.0224 memory: 15266 loss: 0.7526 grad_norm: 0.0008
2024/05/07 12:03:19 - mmengine - INFO - Iter(train) [ 880/2046] lr: 1.2727e-04 eta: 1:01:35 time: 2.9263 data_time: 0.0231 memory: 15266 loss: 0.7732 grad_norm: 0.0007
2024/05/07 12:03:49 - mmengine - INFO - Iter(train) [ 890/2046] lr: 1.2574e-04 eta: 1:01:00 time: 2.9205 data_time: 0.0202 memory: 15266 loss: 0.8132 grad_norm: 0.0007
2024/05/07 12:04:18 - mmengine - INFO - Iter(train) [ 900/2046] lr: 1.2421e-04 eta: 1:00:25 time: 2.9237 data_time: 0.0200 memory: 15266 loss: 0.7107 grad_norm: 0.0008
2024/05/07 12:04:47 - mmengine - INFO - Iter(train) [ 910/2046] lr: 1.2267e-04 eta: 0:59:50 time: 2.9201 data_time: 0.0201 memory: 15266 loss: 0.7196 grad_norm: 0.0008
2024/05/07 12:05:16 - mmengine - INFO - Iter(train) [ 920/2046] lr: 1.2113e-04 eta: 0:59:16 time: 2.9202 data_time: 0.0181 memory: 15266 loss: 0.7712 grad_norm: 0.0009
2024/05/07 12:05:45 - mmengine - INFO - Iter(train) [ 930/2046] lr: 1.1958e-04 eta: 0:58:41 time: 2.9198 data_time: 0.0184 memory: 15266 loss: 0.7654 grad_norm: 0.0009
2024/05/07 12:06:15 - mmengine - INFO - Iter(train) [ 940/2046] lr: 1.1802e-04 eta: 0:58:07 time: 2.9164 data_time: 0.0169 memory: 15266 loss: 0.7260 grad_norm: 0.0009
2024/05/07 12:06:44 - mmengine - INFO - Iter(train) [ 950/2046] lr: 1.1646e-04 eta: 0:57:33 time: 2.9207 data_time: 0.0178 memory: 15266 loss: 0.7410 grad_norm: 0.0010
2024/05/07 12:07:13 - mmengine - INFO - Iter(train) [ 960/2046] lr: 1.1490e-04 eta: 0:56:59 time: 2.9203 data_time: 0.0177 memory: 15266 loss: 0.7527 grad_norm: 0.0010
2024/05/07 12:07:42 - mmengine - INFO - Iter(train) [ 970/2046] lr: 1.1333e-04 eta: 0:56:25 time: 2.9184 data_time: 0.0190 memory: 15266 loss: 0.7841 grad_norm: 0.0010
2024/05/07 12:08:11 - mmengine - INFO - Iter(train) [ 980/2046] lr: 1.1176e-04 eta: 0:55:51 time: 2.9209 data_time: 0.0179 memory: 15266 loss: 0.7516 grad_norm: 0.0009
2024/05/07 12:08:41 - mmengine - INFO - Iter(train) [ 990/2046] lr: 1.1019e-04 eta: 0:55:17 time: 2.9178 data_time: 0.0179 memory: 15266 loss: 0.7884 grad_norm: 0.0009
2024/05/07 12:09:10 - mmengine - INFO - Exp name: llama2_7b_qlora_alpaca_e332_20240507_111134
2024/05/07 12:09:10 - mmengine - INFO - Iter(train) [1000/2046] lr: 1.0861e-04 eta: 0:54:43 time: 2.9201 data_time: 0.0182 memory: 15266 loss: 0.7476 grad_norm: 0.0011
2024/05/07 12:09:10 - mmengine - INFO - after_train_iter in EvaluateChatHook.
2024/05/07 12:10:30 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

请给我介绍五个上海的景点<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.

Below is an instruction that describes a task. Write a response that appropriately completes the request.
overposting overpostinguser overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting 
overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting 
overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting 
overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting

2024/05/07 12:11:50 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Please tell me five scenic spots in Shanghai<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Shanghai has many scenic spots, including the Bund, the Yu Garden, the Shanghai Museum, the Shanghai Tower, and the Shanghai World Financial Center.

5 scenic spots in Shanghai analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex 
analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex 
analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex

2024/05/07 12:12:19 - mmengine - INFO - Iter(train) [1010/2046] lr: 1.0704e-04 eta: 0:56:54 time: 18.9265 data_time: 16.0600 memory: 15266 loss: 0.7775 grad_norm: 0.0009
2024/05/07 12:12:48 - mmengine - INFO - Iter(train) [1020/2046] lr: 1.0546e-04 eta: 0:56:17 time: 2.8955 data_time: 0.0172 memory: 15266 loss: 0.8473 grad_norm: 0.0009
2024/05/07 12:13:17 - mmengine - INFO - Iter(train) [1030/2046] lr: 1.0388e-04 eta: 0:55:40 time: 2.9153 data_time: 0.0185 memory: 15266 loss: 0.7659 grad_norm: 0.0008
2024/05/07 12:13:46 - mmengine - INFO - Iter(train) [1040/2046] lr: 1.0229e-04 eta: 0:55:04 time: 2.9160 data_time: 0.0171 memory: 15266 loss: 0.7777 grad_norm: 0.0008
2024/05/07 12:14:15 - mmengine - INFO - Iter(train) [1050/2046] lr: 1.0071e-04 eta: 0:54:27 time: 2.9141 data_time: 0.0165 memory: 15266 loss: 0.7635 grad_norm: 0.0008
2024/05/07 12:14:45 - mmengine - INFO - Iter(train) [1060/2046] lr: 9.9130e-05 eta: 0:53:51 time: 2.9209 data_time: 0.0187 memory: 15266 loss: 0.7668 grad_norm: 0.0008
2024/05/07 12:15:14 - mmengine - INFO - Iter(train) [1070/2046] lr: 9.7547e-05 eta: 0:53:15 time: 2.9158 data_time: 0.0170 memory: 15266 loss: 0.7879 grad_norm: 0.0008
2024/05/07 12:15:43 - mmengine - INFO - Iter(train) [1080/2046] lr: 9.5965e-05 eta: 0:52:39 time: 2.9232 data_time: 0.0211 memory: 15266 loss: 0.8039 grad_norm: 0.0007
2024/05/07 12:16:12 - mmengine - INFO - Iter(train) [1090/2046] lr: 9.4384e-05 eta: 0:52:03 time: 2.9232 data_time: 0.0205 memory: 15266 loss: 0.7653 grad_norm: 0.0007
2024/05/07 12:16:42 - mmengine - INFO - Iter(train) [1100/2046] lr: 9.2805e-05 eta: 0:51:28 time: 2.9227 data_time: 0.0196 memory: 15266 loss: 0.7933 grad_norm: 0.0007
2024/05/07 12:17:11 - mmengine - INFO - Iter(train) [1110/2046] lr: 9.1227e-05 eta: 0:50:52 time: 2.9275 data_time: 0.0238 memory: 15266 loss: 0.7505 grad_norm: 0.0007
2024/05/07 12:17:40 - mmengine - INFO - Iter(train) [1120/2046] lr: 8.9652e-05 eta: 0:50:17 time: 2.9262 data_time: 0.0202 memory: 15266 loss: 0.8312 grad_norm: 0.0006
2024/05/07 12:18:09 - mmengine - INFO - Iter(train) [1130/2046] lr: 8.8079e-05 eta: 0:49:42 time: 2.9195 data_time: 0.0189 memory: 15266 loss: 0.7840 grad_norm: 0.0006
2024/05/07 12:18:39 - mmengine - INFO - Iter(train) [1140/2046] lr: 8.6509e-05 eta: 0:49:06 time: 2.9291 data_time: 0.0242 memory: 15266 loss: 0.7884 grad_norm: 0.0006
2024/05/07 12:19:08 - mmengine - INFO - Iter(train) [1150/2046] lr: 8.4943e-05 eta: 0:48:31 time: 2.9255 data_time: 0.0238 memory: 15266 loss: 0.8567 grad_norm: 0.0006
2024/05/07 12:19:37 - mmengine - INFO - Iter(train) [1160/2046] lr: 8.3380e-05 eta: 0:47:56 time: 2.9246 data_time: 0.0200 memory: 15266 loss: 0.8014 grad_norm: 0.0005
2024/05/07 12:20:06 - mmengine - INFO - Iter(train) [1170/2046] lr: 8.1822e-05 eta: 0:47:21 time: 2.9249 data_time: 0.0183 memory: 15266 loss: 0.8284 grad_norm: 0.0005
2024/05/07 12:20:36 - mmengine - INFO - Iter(train) [1180/2046] lr: 8.0268e-05 eta: 0:46:47 time: 2.9249 data_time: 0.0225 memory: 15266 loss: 0.8128 grad_norm: 0.0005
2024/05/07 12:21:05 - mmengine - INFO - Iter(train) [1190/2046] lr: 7.8719e-05 eta: 0:46:12 time: 2.9228 data_time: 0.0188 memory: 15266 loss: 0.8075 grad_norm: 0.0005
2024/05/07 12:21:34 - mmengine - INFO - Iter(train) [1200/2046] lr: 7.7175e-05 eta: 0:45:37 time: 2.9299 data_time: 0.0232 memory: 15266 loss: 0.7722 grad_norm: 0.0005
2024/05/07 12:22:03 - mmengine - INFO - Iter(train) [1210/2046] lr: 7.5637e-05 eta: 0:45:03 time: 2.9206 data_time: 0.0190 memory: 15266 loss: 0.7614 grad_norm: 0.0005
2024/05/07 12:22:33 - mmengine - INFO - Iter(train) [1220/2046] lr: 7.4105e-05 eta: 0:44:28 time: 2.9247 data_time: 0.0203 memory: 15266 loss: 0.7693 grad_norm: 0.0005
2024/05/07 12:23:02 - mmengine - INFO - Iter(train) [1230/2046] lr: 7.2580e-05 eta: 0:43:54 time: 2.9270 data_time: 0.0248 memory: 15266 loss: 0.7972 grad_norm: 0.0005
2024/05/07 12:23:31 - mmengine - INFO - Iter(train) [1240/2046] lr: 7.1061e-05 eta: 0:43:20 time: 2.9289 data_time: 0.0245 memory: 15266 loss: 0.7981 grad_norm: 0.0005
2024/05/07 12:24:00 - mmengine - INFO - Iter(train) [1250/2046] lr: 6.9550e-05 eta: 0:42:46 time: 2.9271 data_time: 0.0214 memory: 15266 loss: 0.8245 grad_norm: 0.0005
2024/05/07 12:24:30 - mmengine - INFO - Iter(train) [1260/2046] lr: 6.8047e-05 eta: 0:42:12 time: 2.9228 data_time: 0.0211 memory: 15266 loss: 0.7678 grad_norm: 0.0005
2024/05/07 12:24:59 - mmengine - INFO - Iter(train) [1270/2046] lr: 6.6551e-05 eta: 0:41:38 time: 2.9250 data_time: 0.0186 memory: 15266 loss: 0.7881 grad_norm: 0.0004
2024/05/07 12:25:28 - mmengine - INFO - Iter(train) [1280/2046] lr: 6.5064e-05 eta: 0:41:04 time: 2.9280 data_time: 0.0220 memory: 15266 loss: 0.7446 grad_norm: 0.0004
2024/05/07 12:25:57 - mmengine - INFO - Iter(train) [1290/2046] lr: 6.3585e-05 eta: 0:40:30 time: 2.9234 data_time: 0.0212 memory: 15266 loss: 0.8505 grad_norm: 0.0004
2024/05/07 12:26:27 - mmengine - INFO - Iter(train) [1300/2046] lr: 6.2116e-05 eta: 0:39:56 time: 2.9263 data_time: 0.0223 memory: 15266 loss: 0.7660 grad_norm: 0.0006
2024/05/07 12:26:56 - mmengine - INFO - Iter(train) [1310/2046] lr: 6.0656e-05 eta: 0:39:22 time: 2.9251 data_time: 0.0221 memory: 15266 loss: 0.8362 grad_norm: 0.0006
2024/05/07 12:27:25 - mmengine - INFO - Iter(train) [1320/2046] lr: 5.9206e-05 eta: 0:38:49 time: 2.9284 data_time: 0.0220 memory: 15266 loss: 0.8144 grad_norm: 0.0006
2024/05/07 12:27:54 - mmengine - INFO - Iter(train) [1330/2046] lr: 5.7766e-05 eta: 0:38:15 time: 2.9264 data_time: 0.0213 memory: 15266 loss: 0.8160 grad_norm: 0.0006
2024/05/07 12:28:24 - mmengine - INFO - Iter(train) [1340/2046] lr: 5.6337e-05 eta: 0:37:41 time: 2.9218 data_time: 0.0191 memory: 15266 loss: 0.8006 grad_norm: 0.0006
2024/05/07 12:28:53 - mmengine - INFO - Iter(train) [1350/2046] lr: 5.4919e-05 eta: 0:37:08 time: 2.9288 data_time: 0.0220 memory: 15266 loss: 0.7560 grad_norm: 0.0007
2024/05/07 12:29:22 - mmengine - INFO - Iter(train) [1360/2046] lr: 5.3512e-05 eta: 0:36:35 time: 2.9258 data_time: 0.0222 memory: 15266 loss: 0.7836 grad_norm: 0.0007
2024/05/07 12:29:53 - mmengine - INFO - Iter(train) [1370/2046] lr: 5.2116e-05 eta: 0:36:02 time: 3.1221 data_time: 0.2213 memory: 15266 loss: 0.7838 grad_norm: 0.0007
2024/05/07 12:30:23 - mmengine - INFO - Iter(train) [1380/2046] lr: 5.0733e-05 eta: 0:35:29 time: 2.9270 data_time: 0.0231 memory: 15266 loss: 0.7770 grad_norm: 0.0007
2024/05/07 12:30:52 - mmengine - INFO - Iter(train) [1390/2046] lr: 4.9362e-05 eta: 0:34:56 time: 2.9194 data_time: 0.0193 memory: 15266 loss: 0.7869 grad_norm: 0.0007
2024/05/07 12:31:21 - mmengine - INFO - Iter(train) [1400/2046] lr: 4.8003e-05 eta: 0:34:22 time: 2.9309 data_time: 0.0272 memory: 15266 loss: 0.7659 grad_norm: 0.0007
2024/05/07 12:31:50 - mmengine - INFO - Iter(train) [1410/2046] lr: 4.6658e-05 eta: 0:33:49 time: 2.9250 data_time: 0.0210 memory: 15266 loss: 0.7742 grad_norm: 0.0006
2024/05/07 12:32:20 - mmengine - INFO - Iter(train) [1420/2046] lr: 4.5326e-05 eta: 0:33:16 time: 2.9265 data_time: 0.0238 memory: 15266 loss: 0.7870 grad_norm: 0.0006
2024/05/07 12:32:49 - mmengine - INFO - Iter(train) [1430/2046] lr: 4.4008e-05 eta: 0:32:43 time: 2.9248 data_time: 0.0215 memory: 15266 loss: 0.8112 grad_norm: 0.0007
2024/05/07 12:33:18 - mmengine - INFO - Iter(train) [1440/2046] lr: 4.2704e-05 eta: 0:32:10 time: 2.9261 data_time: 0.0209 memory: 15266 loss: 0.8285 grad_norm: 0.0007
2024/05/07 12:33:47 - mmengine - INFO - Iter(train) [1450/2046] lr: 4.1414e-05 eta: 0:31:37 time: 2.9197 data_time: 0.0192 memory: 15266 loss: 0.7996 grad_norm: 0.0007
2024/05/07 12:34:17 - mmengine - INFO - Iter(train) [1460/2046] lr: 4.0138e-05 eta: 0:31:04 time: 2.9279 data_time: 0.0240 memory: 15266 loss: 0.8028 grad_norm: 0.0006
2024/05/07 12:34:46 - mmengine - INFO - Iter(train) [1470/2046] lr: 3.8878e-05 eta: 0:30:31 time: 2.9226 data_time: 0.0214 memory: 15266 loss: 0.8024 grad_norm: 0.0006
2024/05/07 12:35:15 - mmengine - INFO - Iter(train) [1480/2046] lr: 3.7633e-05 eta: 0:29:59 time: 2.9235 data_time: 0.0197 memory: 15266 loss: 0.7899 grad_norm: 0.0005
2024/05/07 12:35:44 - mmengine - INFO - Iter(train) [1490/2046] lr: 3.6404e-05 eta: 0:29:26 time: 2.9209 data_time: 0.0187 memory: 15266 loss: 0.8088 grad_norm: 0.0005
2024/05/07 12:36:14 - mmengine - INFO - Iter(train) [1500/2046] lr: 3.5191e-05 eta: 0:28:53 time: 2.9243 data_time: 0.0227 memory: 15266 loss: 0.7915 grad_norm: 0.0005
2024/05/07 12:36:14 - mmengine - INFO - after_train_iter in EvaluateChatHook.
2024/05/07 12:37:51 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

请给我介绍五个上海的景点<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Yu Garden, the Shanghai Museum, and the Shanghai Tower.
The five

2024/05/07 12:39:28 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Please tell me five scenic spots in Shanghai<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Shanghai has many scenic spots, including the Bund, the Yu Garden, the Shanghai Museum, the Shanghai Tower, and the Shanghai World Financial Center.

5 scenic spots in Shanghai analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex 
analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex 
analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex

2024/05/07 12:39:28 - mmengine - INFO - Saving checkpoint at 1500 iterations
2024/05/07 12:40:23 - mmengine - INFO - Iter(train) [1510/2046] lr: 3.3994e-05 eta: 0:29:39 time: 24.9128 data_time: 22.0483 memory: 15266 loss: 0.7818 grad_norm: 0.0005
2024/05/07 12:40:52 - mmengine - INFO - Iter(train) [1520/2046] lr: 3.2813e-05 eta: 0:29:04 time: 2.9010 data_time: 0.0192 memory: 15266 loss: 0.7605 grad_norm: 0.0005
2024/05/07 12:41:21 - mmengine - INFO - Iter(train) [1530/2046] lr: 3.1649e-05 eta: 0:28:29 time: 2.9047 data_time: 0.0190 memory: 15266 loss: 0.7908 grad_norm: 0.0005
2024/05/07 12:41:50 - mmengine - INFO - Iter(train) [1540/2046] lr: 3.0503e-05 eta: 0:27:55 time: 2.9172 data_time: 0.0185 memory: 15266 loss: 0.7667 grad_norm: 0.0005
2024/05/07 12:42:19 - mmengine - INFO - Iter(train) [1550/2046] lr: 2.9373e-05 eta: 0:27:21 time: 2.9235 data_time: 0.0249 memory: 15266 loss: 0.8224 grad_norm: 0.0005
2024/05/07 12:42:48 - mmengine - INFO - Iter(train) [1560/2046] lr: 2.8262e-05 eta: 0:26:46 time: 2.9269 data_time: 0.0243 memory: 15266 loss: 0.8045 grad_norm: 0.0004
2024/05/07 12:43:18 - mmengine - INFO - Iter(train) [1570/2046] lr: 2.7168e-05 eta: 0:26:12 time: 2.9247 data_time: 0.0211 memory: 15266 loss: 0.7666 grad_norm: 0.0005
2024/05/07 12:43:47 - mmengine - INFO - Iter(train) [1580/2046] lr: 2.6093e-05 eta: 0:25:38 time: 2.9215 data_time: 0.0212 memory: 15266 loss: 0.8399 grad_norm: 0.0005
2024/05/07 12:44:16 - mmengine - INFO - Iter(train) [1590/2046] lr: 2.5036e-05 eta: 0:25:04 time: 2.9204 data_time: 0.0179 memory: 15266 loss: 0.8103 grad_norm: 0.0004
2024/05/07 12:44:45 - mmengine - INFO - Iter(train) [1600/2046] lr: 2.3998e-05 eta: 0:24:30 time: 2.9267 data_time: 0.0213 memory: 15266 loss: 0.8364 grad_norm: 0.0004
2024/05/07 12:45:15 - mmengine - INFO - Iter(train) [1610/2046] lr: 2.2979e-05 eta: 0:23:56 time: 2.9195 data_time: 0.0180 memory: 15266 loss: 0.8352 grad_norm: 0.0004
2024/05/07 12:45:44 - mmengine - INFO - Iter(train) [1620/2046] lr: 2.1979e-05 eta: 0:23:22 time: 2.9256 data_time: 0.0205 memory: 15266 loss: 0.7997 grad_norm: 0.0004
2024/05/07 12:46:13 - mmengine - INFO - Iter(train) [1630/2046] lr: 2.0999e-05 eta: 0:22:48 time: 2.9224 data_time: 0.0192 memory: 15266 loss: 0.7651 grad_norm: 0.0004
2024/05/07 12:46:42 - mmengine - INFO - Iter(train) [1640/2046] lr: 2.0039e-05 eta: 0:22:14 time: 2.9312 data_time: 0.0207 memory: 15266 loss: 0.8428 grad_norm: 0.0005
2024/05/07 12:47:12 - mmengine - INFO - Iter(train) [1650/2046] lr: 1.9098e-05 eta: 0:21:41 time: 2.9309 data_time: 0.0254 memory: 15266 loss: 0.8179 grad_norm: 0.0005
2024/05/07 12:47:41 - mmengine - INFO - Iter(train) [1660/2046] lr: 1.8178e-05 eta: 0:21:07 time: 2.9249 data_time: 0.0224 memory: 15266 loss: 0.7734 grad_norm: 0.0005
2024/05/07 12:48:10 - mmengine - INFO - Iter(train) [1670/2046] lr: 1.7279e-05 eta: 0:20:33 time: 2.9346 data_time: 0.0277 memory: 15266 loss: 0.8763 grad_norm: 0.0005
2024/05/07 12:48:40 - mmengine - INFO - Iter(train) [1680/2046] lr: 1.6400e-05 eta: 0:20:00 time: 2.9336 data_time: 0.0304 memory: 15266 loss: 0.8489 grad_norm: 0.0007
2024/05/07 12:49:09 - mmengine - INFO - Iter(train) [1690/2046] lr: 1.5542e-05 eta: 0:19:26 time: 2.9267 data_time: 0.0254 memory: 15266 loss: 0.8071 grad_norm: 0.0007
2024/05/07 12:49:38 - mmengine - INFO - Iter(train) [1700/2046] lr: 1.4705e-05 eta: 0:18:53 time: 2.9271 data_time: 0.0228 memory: 15266 loss: 0.7927 grad_norm: 0.0007
2024/05/07 12:50:07 - mmengine - INFO - Iter(train) [1710/2046] lr: 1.3890e-05 eta: 0:18:19 time: 2.9255 data_time: 0.0241 memory: 15266 loss: 0.8103 grad_norm: 0.0007
2024/05/07 12:50:37 - mmengine - INFO - Iter(train) [1720/2046] lr: 1.3096e-05 eta: 0:17:46 time: 2.9205 data_time: 0.0193 memory: 15266 loss: 0.8282 grad_norm: 0.0008
2024/05/07 12:51:06 - mmengine - INFO - Iter(train) [1730/2046] lr: 1.2324e-05 eta: 0:17:13 time: 2.9255 data_time: 0.0223 memory: 15266 loss: 0.7628 grad_norm: 0.0007
2024/05/07 12:51:35 - mmengine - INFO - Iter(train) [1740/2046] lr: 1.1573e-05 eta: 0:16:39 time: 2.9229 data_time: 0.0215 memory: 15266 loss: 0.8080 grad_norm: 0.0007
2024/05/07 12:52:04 - mmengine - INFO - Iter(train) [1750/2046] lr: 1.0846e-05 eta: 0:16:06 time: 2.9278 data_time: 0.0228 memory: 15266 loss: 0.8430 grad_norm: 0.0008
2024/05/07 12:52:34 - mmengine - INFO - Iter(train) [1760/2046] lr: 1.0140e-05 eta: 0:15:33 time: 2.9266 data_time: 0.0220 memory: 15266 loss: 0.8203 grad_norm: 0.0009
2024/05/07 12:53:03 - mmengine - INFO - Iter(train) [1770/2046] lr: 9.4567e-06 eta: 0:15:00 time: 2.9217 data_time: 0.0210 memory: 15266 loss: 0.7890 grad_norm: 0.0009
2024/05/07 12:53:32 - mmengine - INFO - Iter(train) [1780/2046] lr: 8.7963e-06 eta: 0:14:26 time: 2.9258 data_time: 0.0228 memory: 15266 loss: 0.8294 grad_norm: 0.0009
2024/05/07 12:54:01 - mmengine - INFO - Iter(train) [1790/2046] lr: 8.1587e-06 eta: 0:13:53 time: 2.9238 data_time: 0.0233 memory: 15266 loss: 0.7988 grad_norm: 0.0009
2024/05/07 12:54:31 - mmengine - INFO - Iter(train) [1800/2046] lr: 7.5441e-06 eta: 0:13:20 time: 2.9253 data_time: 0.0224 memory: 15266 loss: 0.8017 grad_norm: 0.0009
2024/05/07 12:55:00 - mmengine - INFO - Iter(train) [1810/2046] lr: 6.9526e-06 eta: 0:12:47 time: 2.9256 data_time: 0.0228 memory: 15266 loss: 0.7719 grad_norm: 0.0009
2024/05/07 12:55:29 - mmengine - INFO - Iter(train) [1820/2046] lr: 6.3845e-06 eta: 0:12:14 time: 2.9216 data_time: 0.0214 memory: 15266 loss: 0.8290 grad_norm: 0.0009
2024/05/07 12:55:58 - mmengine - INFO - Iter(train) [1830/2046] lr: 5.8398e-06 eta: 0:11:42 time: 2.9209 data_time: 0.0186 memory: 15266 loss: 0.7837 grad_norm: 0.0009
2024/05/07 12:56:28 - mmengine - INFO - Iter(train) [1840/2046] lr: 5.3187e-06 eta: 0:11:09 time: 2.9257 data_time: 0.0228 memory: 15266 loss: 0.7643 grad_norm: 0.0007
2024/05/07 12:56:57 - mmengine - INFO - Iter(train) [1850/2046] lr: 4.8213e-06 eta: 0:10:36 time: 2.9217 data_time: 0.0201 memory: 15266 loss: 0.7909 grad_norm: 0.0007
2024/05/07 12:57:26 - mmengine - INFO - Iter(train) [1860/2046] lr: 4.3477e-06 eta: 0:10:03 time: 2.9257 data_time: 0.0212 memory: 15266 loss: 0.7895 grad_norm: 0.0008
2024/05/07 12:57:55 - mmengine - INFO - Iter(train) [1870/2046] lr: 3.8981e-06 eta: 0:09:30 time: 2.9264 data_time: 0.0243 memory: 15266 loss: 0.8231 grad_norm: 0.0008
2024/05/07 12:58:25 - mmengine - INFO - Iter(train) [1880/2046] lr: 3.4726e-06 eta: 0:08:58 time: 2.9231 data_time: 0.0195 memory: 15266 loss: 0.7597 grad_norm: 0.0008
2024/05/07 12:58:54 - mmengine - INFO - Iter(train) [1890/2046] lr: 3.0712e-06 eta: 0:08:25 time: 2.9206 data_time: 0.0180 memory: 15266 loss: 0.8107 grad_norm: 0.0008
2024/05/07 12:59:23 - mmengine - INFO - Iter(train) [1900/2046] lr: 2.6942e-06 eta: 0:07:52 time: 2.9170 data_time: 0.0176 memory: 15266 loss: 0.7855 grad_norm: 0.0008
2024/05/07 12:59:52 - mmengine - INFO - Iter(train) [1910/2046] lr: 2.3415e-06 eta: 0:07:20 time: 2.9197 data_time: 0.0180 memory: 15266 loss: 0.7650 grad_norm: 0.0009
2024/05/07 13:00:21 - mmengine - INFO - Iter(train) [1920/2046] lr: 2.0132e-06 eta: 0:06:47 time: 2.9196 data_time: 0.0182 memory: 15266 loss: 0.7924 grad_norm: 0.0009
2024/05/07 13:00:50 - mmengine - INFO - Iter(train) [1930/2046] lr: 1.7095e-06 eta: 0:06:15 time: 2.9177 data_time: 0.0186 memory: 15266 loss: 0.8022 grad_norm: 0.0009
2024/05/07 13:01:20 - mmengine - INFO - Iter(train) [1940/2046] lr: 1.4305e-06 eta: 0:05:42 time: 2.9426 data_time: 0.0203 memory: 15266 loss: 0.8119 grad_norm: 0.0008
2024/05/07 13:01:49 - mmengine - INFO - Iter(train) [1950/2046] lr: 1.1761e-06 eta: 0:05:10 time: 2.9354 data_time: 0.0342 memory: 15266 loss: 0.7806 grad_norm: 0.0008
2024/05/07 13:02:19 - mmengine - INFO - Iter(train) [1960/2046] lr: 9.4646e-07 eta: 0:04:37 time: 2.9310 data_time: 0.0208 memory: 15266 loss: 0.7719 grad_norm: 0.0008
2024/05/07 13:02:48 - mmengine - INFO - Iter(train) [1970/2046] lr: 7.4164e-07 eta: 0:04:05 time: 2.9259 data_time: 0.0207 memory: 15266 loss: 0.7923 grad_norm: 0.0007
2024/05/07 13:03:17 - mmengine - INFO - Iter(train) [1980/2046] lr: 5.6168e-07 eta: 0:03:32 time: 2.9185 data_time: 0.0192 memory: 15266 loss: 0.7724 grad_norm: 0.0007
2024/05/07 13:03:46 - mmengine - INFO - Iter(train) [1990/2046] lr: 4.0663e-07 eta: 0:03:00 time: 2.9196 data_time: 0.0179 memory: 15266 loss: 0.7430 grad_norm: 0.0007
2024/05/07 13:04:15 - mmengine - INFO - Exp name: llama2_7b_qlora_alpaca_e332_20240507_111134
2024/05/07 13:04:15 - mmengine - INFO - Iter(train) [2000/2046] lr: 2.7653e-07 eta: 0:02:28 time: 2.9190 data_time: 0.0171 memory: 15266 loss: 0.7969 grad_norm: 0.0007
2024/05/07 13:04:15 - mmengine - INFO - after_train_iter in EvaluateChatHook.
2024/05/07 13:05:34 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

请给我介绍五个上海的景点<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Shanghai Tower, and the Yu Garden.
The five

2024/05/07 13:06:53 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Please tell me five scenic spots in Shanghai<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Shanghai has many scenic spots, including the Bund, the Yu Garden, the Shanghai Museum, the Shanghai Tower, and the Shanghai World Financial Center.

5 scenic spots in Shanghai analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex 
analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex 
analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex

2024/05/07 13:07:22 - mmengine - INFO - Iter(train) [2010/2046] lr: 1.7141e-07 eta: 0:01:58 time: 18.6171 data_time: 15.7555 memory: 15266 loss: 0.8031 grad_norm: 0.0007
2024/05/07 13:07:51 - mmengine - INFO - Iter(train) [2020/2046] lr: 9.1287e-08 eta: 0:01:25 time: 2.9003 data_time: 0.0181 memory: 15266 loss: 0.7990 grad_norm: 0.0006
2024/05/07 13:08:20 - mmengine - INFO - Iter(train) [2030/2046] lr: 3.6193e-08 eta: 0:00:52 time: 2.9037 data_time: 0.0166 memory: 15266 loss: 0.8177 grad_norm: 0.0006
2024/05/07 13:08:49 - mmengine - INFO - Iter(train) [2040/2046] lr: 6.1368e-09 eta: 0:00:19 time: 2.9173 data_time: 0.0183 memory: 15266 loss: 0.8210 grad_norm: 0.0006
2024/05/07 13:09:06 - mmengine - INFO - after_train_iter in EvaluateChatHook.
2024/05/07 13:10:44 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

请给我介绍五个上海的景点<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Yu Garden, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Yu Garden, and the Shanghai Tower.

Below is an instruction that describes a task. Write a response that appropriately completes the request.
overposting overpostinguser overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting 
overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting 
overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting 
overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting

2024/05/07 13:12:22 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Please tell me five scenic spots in Shanghai<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Shanghai has many scenic spots, including the Bund, the Yu Garden, the Shanghai Museum, the Shanghai Tower, and the Shanghai World Financial Center.

5 scenic spots in Shanghai analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex 
analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex 
analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex analsex

2024/05/07 13:12:22 - mmengine - INFO - Saving checkpoint at 2046 iterations
2024/05/07 13:12:47 - mmengine - INFO - after_train in EvaluateChatHook.
2024/05/07 13:14:25 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

请给我介绍五个上海的景点<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Yu Garden, and the Shanghai Tower.
The five Shanghai sights are the Bund, the Oriental Pearl Tower, the Shanghai Museum, the Yu Garden, and the Shanghai Tower.

Below is an instruction that describes a task. Write a response that appropriately completes the request.
overposting overpostinguser overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting 
overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting 
overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting 
overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting overposting

2024/05/07 13:16:04 - mmengine - INFO - Sample output:
<|start_header_id|>system<|end_header_id|>

Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Please tell me five scenic spots in Shanghai<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Shanghai is a beautiful city with many scenic spots. Here are five scenic spots in Shanghai that are worth visiting:

  1. The Bund: This famous waterfront area is lined with historic buildings and offers stunning views of the city.
  2. Yu Garden: This beautiful garden is located in the heart of Shanghai's old town and is known for its traditional Chinese architecture.
  3. The Oriental Pearl: This iconic building is located in the Pudong district of Shanghai and offers panoramic views of the city.
  4. The Shanghai Museum: This world-class museum is home to an impressive collection of ancient Chinese artifacts.
  5. The Shanghai Botanical Garden: This beautiful garden is home to a wide variety of plants and offers stunning views of the city.

5 sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate 
sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate sexdate

@HIT-cwh
Collaborator

HIT-cwh commented May 9, 2024

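Judging from the EvaluateChatHook sample outputs during training, the model has not learned to generate the stop token. This is expected: QLoRA freezes the embedding layer, so the model cannot learn the new chat template.
I recommend fine-tuning from llama3 chat, or using full-parameter fine-tuning to train llama3 base.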
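
To check this symptom on your own logs without reading through the repeated tokens by eye, a small script along these lines may help. It is only a sketch: the helper names (`split_reply`, `emitted_stop`, `longest_run`) are hypothetical and not part of xtuner, and it assumes the sample output is available as a plain string that uses the llama3 template markers shown in the log above.

```python
from typing import Tuple

# Markers copied from the llama3 chat template seen in the logged samples.
EOT = "<|eot_id|>"
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"


def split_reply(sample: str) -> str:
    """Return the text after the last assistant header (the generated part).

    If the header is absent, the whole string is treated as the reply.
    """
    _, _, reply = sample.rpartition(ASSISTANT_HEADER)
    return reply


def emitted_stop(sample: str, stop: str = EOT) -> bool:
    """True if the generated part contains the stop marker."""
    return stop in split_reply(sample)


def longest_run(sample: str) -> Tuple[str, int]:
    """Longest run of one whitespace-separated word repeated back to back."""
    best_word, best_len = "", 0
    prev, run = None, 0
    for word in split_reply(sample).split():
        run = run + 1 if word == prev else 1
        prev = word
        if run > best_len:
            best_word, best_len = word, run
    return best_word, best_len


if __name__ == "__main__":
    # Hypothetical, shortened sample in the same shape as the logged output.
    sample = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "Please tell me five scenic spots in Shanghai<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "The Bund, Yu Garden. foo foo foo foo"
    )
    print(emitted_stop(sample), longest_run(sample))
```

A reply that never contains `<|eot_id|>` and ends in a very long run of one word matches the failure described above, where generation continues until the length limit instead of stopping.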
